jijinAI commited on
Commit
b2cdef1
·
1 Parent(s): 0b1a357

Update app.py to use model repository

Browse files
Files changed (2) hide show
  1. app.py +23 -13
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,25 +1,35 @@
1
- #cell
2
  from fastai.vision.all import *
 
3
  import gradio as gr
 
4
 
5
- learn = load_learner('model.pkl')
6
- labels = learn.dls.vocab
 
 
 
 
7
 
8
  def classify_image(img):
9
  img = PILImage.create(img)
10
- pred,pred_idx,probs = learn.predict(img)
11
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
12
 
13
- # Cell
14
  image = gr.Image()
15
 
16
-
17
- label_list ="<ul>\n"
18
- label_list += "\n".join(["<li>" + str(s) + "</li>" for s in labels])
19
  label_list += "\n</ul>"
20
- description= "Try out this bird classifier by uploading an image of a species listed below and hit submit" + label_list
21
 
 
 
 
 
 
 
 
22
 
23
- intf = gr.Interface(fn=classify_image, title="NZ Endangered Bird Classifier", inputs=image, article=description, outputs=gr.Label(num_top_classes=len(labels)))
24
- intf.launch(inline=False)
25
-
 
 
1
  from fastai.vision.all import *
2
+ from fastai.vision.core import PILImage
3
  import gradio as gr
4
+ from transformers import pipeline
5
 
6
+ # Load the pre-trained model and tokenizer
7
+ model_name = "jijinAI/bird-detection"
8
+ classifier = pipeline("image-classification", model=model_name)
9
+
10
+ # Get labels from the model
11
+ labels = classifier.model.config.id2label
12
 
13
  def classify_image(img):
14
  img = PILImage.create(img)
15
+ predictions = classifier(img)
16
+ return {labels[pred['label']]: float(pred['score']) for pred in predictions}
17
 
18
+ # Create Gradio interface
19
  image = gr.Image()
20
 
21
+ label_list = "<ul>\n"
22
+ label_list += "\n".join(["<li>" + str(s) + "</li>" for s in labels.values()])
 
23
  label_list += "\n</ul>"
24
+ description = "Try out this bird classifier by uploading an image of a species listed below and hit submit" + label_list
25
 
26
+ interface = gr.Interface(
27
+ fn=classify_image,
28
+ inputs=image,
29
+ outputs=gr.outputs.Label(num_top_classes=3),
30
+ title="Bird Species Classifier",
31
+ description=description
32
+ )
33
 
34
+ if __name__ == "__main__":
35
+ interface.launch()
 
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  fastai
 
2
  flask
 
1
  fastai
2
+ transformers
3
  flask