Ahmad-01 commited on
Commit
35ffd37
·
verified ·
1 Parent(s): 871d62e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch
4
+ from torchvision import transforms, models
5
+
6
+ # Load trained model
7
+ checkpoint = torch.load("animal_model.pth", map_location="cpu")
8
+ class_names = checkpoint["class_names"]
9
+
10
+ # Define model architecture
11
+ model = models.resnet50(weights=None) # same as trained
12
+ model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
13
+ model.load_state_dict(checkpoint["model_state_dict"])
14
+ model.eval()
15
+
16
+ # Image preprocessing
17
+ transform = transforms.Compose([
18
+ transforms.Resize((224, 224)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
21
+ std=[0.229, 0.224, 0.225])
22
+ ])
23
+
24
+ # Prediction function
25
+ def predict(image):
26
+ img = Image.fromarray(image).convert("RGB")
27
+ img = transform(img).unsqueeze(0) # add batch dimension
28
+
29
+ with torch.no_grad():
30
+ outputs = model(img)
31
+ _, pred = torch.max(outputs, 1)
32
+
33
+ return class_names[pred.item()]
34
+
35
+ # Gradio Interface
36
+ app = gr.Interface(
37
+ fn=predict,
38
+ inputs=gr.Image(type="numpy"),
39
+ outputs="text",
40
+ title="Animal Image Classifier",
41
+ description="Upload an image of an animal and the model will classify it."
42
+ )
43
+
44
+ if __name__ == "__main__":
45
+ app.launch()