StanKarz commited on
Commit
0c633e0
·
verified ·
1 Parent(s): 92a90ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -4,28 +4,28 @@ import gradio as gr
4
  # Load the model
5
  learn = load_learner('model.pkl')
6
 
7
- # Define the prediction function
8
  labels = learn.dls.vocab
9
 
10
-
11
  def predict(img):
12
  try:
 
13
  img = PILImage.create(img)
 
14
  pred, pred_idx, probs = learn.predict(img)
15
- # Convert boolean keys to strings and ensure probabilities are floats
16
- return {"Bauhaus": float(probs[0]), "Bohemian": float(probs[1]), "Scandinavian": float(probs[2])}
17
  except Exception as e:
18
  # Log the exception and return it as an error message
19
  print(f"An error occurred: {e}")
20
  return {"error": str(e)}
21
 
22
-
23
  # Define the Gradio interface
24
  title = "Interior Design Classifier"
25
  description = "Upload an image of an interior design and get a prediction of the design style."
26
  examples = ['1.jpeg', '2.jpg', '3.jpg']
27
 
28
-
29
  # Set up Gradio interface
30
  interface = gr.Interface(
31
  fn=predict,
 
4
  # Load the model
5
  learn = load_learner('model.pkl')
6
 
7
+ # Fetch the labels dynamically from the model's vocabulary
8
  labels = learn.dls.vocab
9
 
10
+ # Define the prediction function
11
  def predict(img):
12
  try:
13
+ # Create the image object
14
  img = PILImage.create(img)
15
+ # Get predictions from the model
16
  pred, pred_idx, probs = learn.predict(img)
17
+ # Convert probabilities to percentages and construct the output dictionary dynamically
18
+ return {labels[i]: f"{probs[i]*100:.2f}%" for i in range(len(labels))}
19
  except Exception as e:
20
  # Log the exception and return it as an error message
21
  print(f"An error occurred: {e}")
22
  return {"error": str(e)}
23
 
 
24
  # Define the Gradio interface
25
  title = "Interior Design Classifier"
26
  description = "Upload an image of an interior design and get a prediction of the design style."
27
  examples = ['1.jpeg', '2.jpg', '3.jpg']
28
 
 
29
  # Set up Gradio interface
30
  interface = gr.Interface(
31
  fn=predict,