NicolasvonRotz commited on
Commit
fdcd337
·
1 Parent(s): cb8326d
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from gradio.mix import Parallel
3
  from fastai.vision.all import load_learner
4
 
 
5
  def classify_image_color(img):
6
  from fastai.vision.all import load_learner
7
  learn = load_learner('model-color.pkl')
@@ -9,6 +10,7 @@ def classify_image_color(img):
9
  pred, idx, probs = learn.predict(img)
10
  return {category: float(prob) for category, prob in zip(categories, probs)}
11
 
 
12
  def classify_image_shape(img):
13
  from fastai.vision.all import load_learner
14
  learn = load_learner('bricks-model.pkl')
@@ -16,6 +18,7 @@ def classify_image_shape(img):
16
  pred, idx, probs = learn.predict(img)
17
  return {category: float(prob) for category, prob in zip(categories, probs)}
18
 
 
19
  def classify_image(img):
20
  color_result = classify_image_color(img)
21
  shape_result = classify_image_shape(img)
@@ -24,6 +27,7 @@ def classify_image(img):
24
  result[key] = {"color": color_result.get(key, 0.0), "shape": shape_result.get(key, 0.0)}
25
  return result
26
 
 
27
  def postprocess(prediction):
28
  sorted_pred = sorted(prediction.items(), key=lambda x: x[1], reverse=True)
29
  return sorted_pred
@@ -33,11 +37,12 @@ image = gr.inputs.Image(shape=(256, 256))
33
  label = gr.outputs.Label()
34
 
35
  intf = gr.Interface(
36
- fn=classify_image,
37
- inputs=image,
38
- outputs=label,
39
- examples="",
40
- title="Lego Brick Classifier",
41
- layout="vertical"
 
42
  )
43
  intf.launch()
 
2
  from gradio.mix import Parallel
3
  from fastai.vision.all import load_learner
4
 
5
+
6
  def classify_image_color(img):
7
  from fastai.vision.all import load_learner
8
  learn = load_learner('model-color.pkl')
 
10
  pred, idx, probs = learn.predict(img)
11
  return {category: float(prob) for category, prob in zip(categories, probs)}
12
 
13
+
14
  def classify_image_shape(img):
15
  from fastai.vision.all import load_learner
16
  learn = load_learner('bricks-model.pkl')
 
18
  pred, idx, probs = learn.predict(img)
19
  return {category: float(prob) for category, prob in zip(categories, probs)}
20
 
21
+
22
  def classify_image(img):
23
  color_result = classify_image_color(img)
24
  shape_result = classify_image_shape(img)
 
27
  result[key] = {"color": color_result.get(key, 0.0), "shape": shape_result.get(key, 0.0)}
28
  return result
29
 
30
+
31
  def postprocess(prediction):
32
  sorted_pred = sorted(prediction.items(), key=lambda x: x[1], reverse=True)
33
  return sorted_pred
 
37
  label = gr.outputs.Label()
38
 
39
  intf = gr.Interface(
40
+ fn=classify_image,
41
+ inputs=image,
42
+ outputs=label,
43
+ examples="",
44
+ title="Lego Brick Classifier",
45
+ layout="vertical",
46
+ postprocess=postprocess
47
  )
48
  intf.launch()