Anum15 commited on
Commit
2fec44e
·
verified ·
1 Parent(s): e7f9a79

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -0
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ViTImageProcessor, ViTForImageClassification
2
+ import torch
3
+ import gradio as gr
4
+ from PIL import Image
5
+
6
+ # Load general ViT model (ImageNet pretrained)
7
+ model_name = "google/vit-base-patch16-224"
8
+ processor = ViTImageProcessor.from_pretrained(model_name)
9
+ model = ViTForImageClassification.from_pretrained(model_name)
10
+
11
+ def predict(image):
12
+ if image is None:
13
+ return "Please upload an image."
14
+
15
+ # Preprocess image
16
+ inputs = processor(images=image, return_tensors="pt")
17
+
18
+ with torch.no_grad():
19
+ outputs = model(**inputs)
20
+ logits = outputs.logits
21
+
22
+ probs = torch.nn.functional.softmax(logits, dim=1)
23
+ conf, predicted_class = torch.max(probs, dim=1)
24
+ label = model.config.id2label[predicted_class.item()]
25
+ confidence = conf.item() * 100
26
+
27
+ # This label will be a general ImageNet class, e.g. "banana", "bee", "daisy"
28
+ return f"Detected class: {label}\nConfidence: {confidence:.2f}%"
29
+
30
+ gr.Interface(
31
+ fn=predict,
32
+ inputs=gr.Image(type="pil"),
33
+ outputs="text",
34
+ title="General Image Classification with ViT",
35
+ description="Upload an image to classify using ViT pretrained on ImageNet."
36
+ ).launch()
37
+