mawady-uni commited on
Commit
b0bb146
·
verified ·
1 Parent(s): 9e9f1c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import models, transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import pandas as pd
6
+
7
+ # Load pre-trained models
8
+ resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
9
+ resnet18.eval()
10
+
11
+ convnext_tiny = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT)
12
+ convnext_tiny.eval()
13
+
14
+ # Image preprocessing
15
+ preprocess = transforms.Compose([
16
+ transforms.Resize(256),
17
+ transforms.CenterCrop(224),
18
+ transforms.ToTensor(),
19
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
20
+ std=[0.229, 0.224, 0.225]),
21
+ ])
22
+
23
+ # Load class labels
24
+ labels = models.ResNet18_Weights.DEFAULT.meta['categories']
25
+
26
+ # Function to get top-k predictions
27
+ def get_topk(output, k=3):
28
+ probs = torch.nn.functional.softmax(output, dim=1)
29
+ topk_probs, topk_indices = torch.topk(probs, k)
30
+ return [(labels[idx], float(prob) * 100) for idx, prob in zip(topk_indices[0], topk_probs[0])]
31
+
32
+ # Inference function
33
+ def classify_image(image):
34
+ image = Image.fromarray(image)
35
+ input_tensor = preprocess(image).unsqueeze(0) # Add batch dimension
36
+
37
+ # ResNet18 top-3 predictions
38
+ with torch.no_grad():
39
+ resnet_output = resnet18(input_tensor)
40
+ resnet_top3 = get_topk(resnet_output)
41
+
42
+ # ConvNeXt-Tiny top-3 predictions
43
+ with torch.no_grad():
44
+ convnext_output = convnext_tiny(input_tensor)
45
+ convnext_top3 = get_topk(convnext_output)
46
+
47
+ # Create DataFrame for table display
48
+ df = pd.DataFrame({
49
+ "Rank": [1, 2, 3],
50
+ "ResNet-18": [f"{label} ({prob:.2f}%)" for label, prob in resnet_top3],
51
+ "ConvNeXt-Tiny": [f"{label} ({prob:.2f}%)" for label, prob in convnext_top3]
52
+ })
53
+ return df
54
+
55
+ # Gradio interface
56
+ iface = gr.Interface(
57
+ fn=classify_image,
58
+ inputs=gr.Image(type="numpy"),
59
+ outputs=gr.Dataframe(headers=["Rank", "ResNet-18", "ConvNeXt-Tiny"], type="pandas"),
60
+ title="Image Classification Validator",
61
+ description="Upload an AI-generated image to see top-3 predictions from ResNet-18 and ConvNeXt-Tiny with probabilities."
62
+ )
63
+
64
+ if __name__ == "__main__":
65
+ iface.launch()