NicolasvonRotz commited on
Commit
e8c97c4
·
1 Parent(s): c02f1ff
Files changed (1) hide show
  1. app.py +14 -21
app.py CHANGED
@@ -1,48 +1,41 @@
1
  import gradio as gr
 
 
 
2
  from fastai.vision.all import load_learner
3
 
4
-
5
  def classify_image_color(img):
6
  learn = load_learner('model-color.pkl')
7
  categories = learn.dls.vocab
8
- _, _, probs = learn.predict(img)
9
- return [(category, float(prob)) for category, prob in zip(categories, probs)]
10
 
11
  def classify_image_shape(img):
12
  learn = load_learner('bricks-model.pkl')
13
  categories = learn.dls.vocab
14
- _, _, probs = learn.predict(img)
15
- return [(category, float(prob)) for category, prob in zip(categories, probs)]
16
 
17
  def classify_image(img):
18
- color_result = dict(classify_image_color(img))
19
- shape_result = dict(classify_image_shape(img))
20
  result = {}
21
  for key in set(color_result.keys()) | set(shape_result.keys()):
22
  result[key] = {"color": color_result.get(key, 0.0), "shape": shape_result.get(key, 0.0)}
23
  return result
24
 
25
  def postprocess(prediction):
26
- result = {}
27
- for key, value in prediction.items():
28
- result[key] = []
29
- for inner_key, inner_value in value.items():
30
- result[key].append((inner_key, inner_value))
31
- result[key].sort(key=lambda x: x[1], reverse=True)
32
- return result
33
-
34
-
35
 
36
  image = gr.inputs.Image(shape=(256, 256))
37
- label = gr.outputs.Label()
38
 
39
  intf = gr.Interface(
40
  fn=classify_image,
41
  inputs=image,
42
- outputs=label,
43
  examples="",
44
  title="Lego Brick Classifier",
45
- layout="vertical",
46
- postprocess=postprocess
47
  )
48
- intf.launch()
 
1
  import gradio as gr
2
+ import requests
3
+ import json
4
+ from gradio.mix import Parallel
5
  from fastai.vision.all import load_learner
6
 
 
7
  def classify_image_color(img):
8
  learn = load_learner('model-color.pkl')
9
  categories = learn.dls.vocab
10
+ pred, idx, probs = learn.predict(img)
11
+ return {category: float(prob) for category, prob in zip(categories, probs)}
12
 
13
  def classify_image_shape(img):
14
  learn = load_learner('bricks-model.pkl')
15
  categories = learn.dls.vocab
16
+ pred, idx, probs = learn.predict(img)
17
+ return {category: float(prob) for category, prob in zip(categories, probs)}
18
 
19
  def classify_image(img):
20
+ color_result = classify_image_color(img)
21
+ shape_result = classify_image_shape(img)
22
  result = {}
23
  for key in set(color_result.keys()) | set(shape_result.keys()):
24
  result[key] = {"color": color_result.get(key, 0.0), "shape": shape_result.get(key, 0.0)}
25
  return result
26
 
27
  def postprocess(prediction):
28
+ return json.dumps(prediction)
 
 
 
 
 
 
 
 
29
 
30
  image = gr.inputs.Image(shape=(256, 256))
31
+ output_json = gr.outputs.Textbox(type="auto", label="JSON Output")
32
 
33
  intf = gr.Interface(
34
  fn=classify_image,
35
  inputs=image,
36
+ outputs=output_json,
37
  examples="",
38
  title="Lego Brick Classifier",
39
+ layout="vertical"
 
40
  )
41
+ intf.launch(share=True)