NicolasvonRotz commited on
Commit
67edaec
·
1 Parent(s): abd1e19
Files changed (1) hide show
  1. app.py +32 -9
app.py CHANGED
@@ -1,18 +1,41 @@
1
  import gradio as gr
 
2
  from fastai.vision.all import load_learner
3
 
4
  def classify_image_color(img):
5
- learn_color = load_learner('model-color.pkl')
6
- pred,_,probs = learn_color.predict(img)
7
- return dict(zip(learn_color.dls.vocab, map(float,probs)))
 
 
8
 
9
  def classify_image_shape(img):
10
- learn_shape = load_learner('bricks-model.pkl')
11
- pred,_,probs = learn_shape.predict(img)
12
- return dict(zip(learn_shape.dls.vocab, map(float,probs)))
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  image = gr.inputs.Image(shape=(256, 256))
15
- label = gr.outputs.Label(num_top_classes=3)
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- intf = gr.Interface(fn=[classify_image_color, classify_image_shape], inputs=image, outputs=label, examples="", title="Lego Brick Classifier", layout="vertical")
18
- intf.launch(share=True)
 
1
  import gradio as gr
2
+ from gradio.mix import Parallel
3
  from fastai.vision.all import load_learner
4
 
5
  def classify_image_color(img):
6
+ from fastai.vision.all import load_learner
7
+ learn = load_learner('model-color')
8
+ categories = learn.dls.vocab
9
+ pred, idx, probs = learn.predict(img)
10
+ return {category: float(prob) for category, prob in zip(categories, probs)}
11
 
12
  def classify_image_shape(img):
13
+ from fastai.vision.all import load_learner
14
+ learn = load_learner('bricks-model.pkl')
15
+ categories = learn.dls.vocab
16
+ pred, idx, probs = learn.predict(img)
17
+ return {category: float(prob) for category, prob in zip(categories, probs)}
18
+
19
+ def classify_image(img):
20
+ color_result = classify_image_color(img)
21
+ shape_result = classify_image_shape(img)
22
+ result = {}
23
+ for key in set(color_result.keys()) | set(shape_result.keys()):
24
+ result[key] = {"color": color_result.get(key, 0.0), "shape": shape_result.get(key, 0.0)}
25
+ return result
26
+
27
 
28
  image = gr.inputs.Image(shape=(256, 256))
29
+ label = gr.outputs.Label()
30
+
31
+ intf = gr.Interface(
32
+ fn=classify_image,
33
+ inputs=image,
34
+ outputs=label,
35
+ examples="",
36
+ title="Lego Brick Classifier",
37
+ layout="vertical"
38
+ )
39
+ intf.launch()
40
+
41