srinikesh1432 commited on
Commit
4ac6722
·
verified ·
1 Parent(s): 8883ec5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -25
app.py CHANGED
@@ -1,25 +1,32 @@
1
- import gradio as gr
2
- from transformers import pipeline
3
- from PIL import Image
4
-
5
- # Load an image classification pipeline
6
- classifier = pipeline("image-classification", model="google/vit-base-patch16-224")
7
-
8
- def classify_image(img, top_k=3):
9
- if img is None:
10
- return {"Error": 1.0}
11
- results = classifier(img, top_k=top_k)
12
- # Return as {label: score} for Gradio Label component
13
- return {r["label"]: float(r["score"]) for r in results}
14
-
15
- # Gradio interface
16
- demo = gr.Interface(
17
- fn=classify_image,
18
- inputs=[gr.Image(type="pil", label="Upload Image"), gr.Slider(1, 5, value=3, label="Top K Predictions")],
19
- outputs=gr.Label(num_top_classes=5, label="Predictions"),
20
- title="Image Classification App",
21
- description="Upload an image and the model will predict the top objects in it."
22
- )
23
-
24
- if __name__ == "__main__":
25
- demo.launch()
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ from PIL import Image
4
+
5
+ # Use ResNet-50 model (1000 common ImageNet categories like dog, cat, car, etc.)
6
+ classifier = pipeline("image-classification", model="microsoft/resnet-50")
7
+
8
+ def classify_image(img, top_k=3):
9
+ """
10
+ Takes an uploaded image, runs classification,
11
+ and returns the top_k labels with confidence scores.
12
+ """
13
+ if img is None:
14
+ return {"Error": 1.0}
15
+
16
+ results = classifier(img, top_k=top_k)
17
+ return {r["label"]: float(r["score"]) for r in results}
18
+
19
+ # Gradio interface
20
+ demo = gr.Interface(
21
+ fn=classify_image,
22
+ inputs=[
23
+ gr.Image(type="pil", label="Upload Image"),
24
+ gr.Slider(1, 5, value=3, step=1, label="Top K Predictions")
25
+ ],
26
+ outputs=gr.Label(num_top_classes=5, label="Predictions"),
27
+ title="Image Classification App",
28
+ description="Upload an image and the model will predict the top objects in it."
29
+ )
30
+
31
+ if __name__ == "__main__":
32
+ demo.launch()