NicolasvonRotz commited on
Commit
b403237
·
1 Parent(s): fdcd337
Files changed (1) hide show
  1. app.py +9 -14
app.py CHANGED
@@ -2,32 +2,28 @@ import gradio as gr
2
  from gradio.mix import Parallel
3
  from fastai.vision.all import load_learner
4
 
5
-
6
  def classify_image_color(img):
7
  from fastai.vision.all import load_learner
8
  learn = load_learner('model-color.pkl')
9
  categories = learn.dls.vocab
10
  pred, idx, probs = learn.predict(img)
11
- return {category: float(prob) for category, prob in zip(categories, probs)}
12
-
13
 
14
  def classify_image_shape(img):
15
  from fastai.vision.all import load_learner
16
  learn = load_learner('bricks-model.pkl')
17
  categories = learn.dls.vocab
18
  pred, idx, probs = learn.predict(img)
19
- return {category: float(prob) for category, prob in zip(categories, probs)}
20
-
21
 
22
  def classify_image(img):
23
  color_result = classify_image_color(img)
24
  shape_result = classify_image_shape(img)
25
  result = {}
26
  for key in set(color_result.keys()) | set(shape_result.keys()):
27
- result[key] = {"color": color_result.get(key, 0.0), "shape": shape_result.get(key, 0.0)}
28
  return result
29
 
30
-
31
  def postprocess(prediction):
32
  sorted_pred = sorted(prediction.items(), key=lambda x: x[1], reverse=True)
33
  return sorted_pred
@@ -37,12 +33,11 @@ image = gr.inputs.Image(shape=(256, 256))
37
  label = gr.outputs.Label()
38
 
39
  intf = gr.Interface(
40
- fn=classify_image,
41
- inputs=image,
42
- outputs=label,
43
- examples="",
44
- title="Lego Brick Classifier",
45
- layout="vertical",
46
- postprocess=postprocess
47
  )
48
  intf.launch()
 
2
  from gradio.mix import Parallel
3
  from fastai.vision.all import load_learner
4
 
 
5
  def classify_image_color(img):
6
  from fastai.vision.all import load_learner
7
  learn = load_learner('model-color.pkl')
8
  categories = learn.dls.vocab
9
  pred, idx, probs = learn.predict(img)
10
+ return {f"{category}_color": float(prob) for category, prob in zip(categories, probs)}
 
11
 
12
  def classify_image_shape(img):
13
  from fastai.vision.all import load_learner
14
  learn = load_learner('bricks-model.pkl')
15
  categories = learn.dls.vocab
16
  pred, idx, probs = learn.predict(img)
17
+ return {f"{category}_shape": float(prob) for category, prob in zip(categories, probs)}
 
18
 
19
  def classify_image(img):
20
  color_result = classify_image_color(img)
21
  shape_result = classify_image_shape(img)
22
  result = {}
23
  for key in set(color_result.keys()) | set(shape_result.keys()):
24
+ result[key] = color_result.get(key, 0.0) + shape_result.get(key, 0.0)
25
  return result
26
 
 
27
  def postprocess(prediction):
28
  sorted_pred = sorted(prediction.items(), key=lambda x: x[1], reverse=True)
29
  return sorted_pred
 
33
  label = gr.outputs.Label()
34
 
35
  intf = gr.Interface(
36
+ fn=classify_image,
37
+ inputs=image,
38
+ outputs=label,
39
+ examples="",
40
+ title="Lego Brick Classifier",
41
+ layout="vertical"
 
42
  )
43
  intf.launch()