GenAIDevTOProd committed on
Commit
d97b069
·
verified ·
1 Parent(s): 725e438

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torchvision.models as models
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from PIL import Image
8
+
9
# CIFAR-10 label names; the list position is the class index that gets
# looked up when turning a model prediction into a name.
classes = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
12
+
13
# Build a ResNet18 with a 10-way head for CIFAR-10 and load the fine-tuned
# checkpoint on CPU.
#
# No ImageNet weights are requested: load_state_dict() (strict by default)
# overwrites every parameter and buffer, so asking for pretrained weights only
# triggered a needless download, and the `pretrained=` keyword is deprecated
# in recent torchvision releases.
resnet18 = models.resnet18()
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, 10)  # Replace final layer

# NOTE(review): this is an absolute, Colab-style path -- confirm the
# checkpoint actually exists at this location wherever the app is deployed.
resnet18.load_state_dict(
    torch.load(
        "/content/sample_data/resnet18_fft_cifar10.pth",
        map_location=torch.device('cpu'),
    )
)

# Switch to inference mode before serving predictions.
resnet18.eval()
18
+
19
# Preprocessing applied to every uploaded image before it reaches the model.
_preprocess_steps = [
    # Bring every upload to the 224x224 input size used for ResNet18.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # Per-channel normalisation with mean 0.5 and std 0.5.
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocess_steps)
25
+
26
+ # FFT Visualizer
27
def apply_fft_visualization(image: Image.Image):
    """Plot the log-magnitude 2-D FFT spectrum of each RGB channel.

    The image is converted to RGB, resized to 32x32, and scaled to [0, 1].
    Each colour channel is then transformed independently with a 2-D FFT,
    shifted so the zero-frequency component sits in the centre, and shown as
    ``log(1 + |F|)``.

    Args:
        image: PIL image in any mode; it is converted to RGB internally.

    Returns:
        A matplotlib ``Figure`` with three side-by-side panels titled
        "Red", "Green" and "Blue".
    """
    # Local import: a Figure built directly is not registered in pyplot's
    # global figure registry, so repeated requests in a long-running server
    # do not accumulate open figures.
    from matplotlib.figure import Figure

    channel_names = ('Red', 'Green', 'Blue')

    # Force exactly three channels so grayscale / palette / RGBA uploads do
    # not break the per-channel indexing below.
    rgb_small = image.convert("RGB").resize((32, 32))
    img_np = np.asarray(rgb_small) / 255.0

    magnitudes = [
        np.log1p(np.abs(np.fft.fftshift(np.fft.fft2(img_np[:, :, idx]))))
        for idx in range(len(channel_names))
    ]

    fig = Figure(figsize=(12, 4))
    axs = fig.subplots(1, 3)
    for ax, magnitude, name in zip(axs, magnitudes, channel_names):
        ax.imshow(magnitude, cmap='inferno')
        ax.set_title(name)
        ax.axis('off')
    fig.tight_layout()
    return fig
44
+
45
+ # Prediction Function
46
def predict(img: Image.Image, mode="Raw"):
    """Classify an uploaded image or visualise its frequency spectrum.

    Args:
        img: The uploaded PIL image, or ``None`` when nothing was uploaded.
        mode: ``"FFT"`` returns the per-channel FFT figure; any other value
            (default ``"Raw"``) runs ResNet18 classification.

    Returns:
        A ``(label, figure)`` tuple. In FFT mode ``label`` is ``None`` and
        ``figure`` is the spectrum plot; otherwise ``label`` is the predicted
        class name and ``figure`` is ``None``.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Without an upload there is nothing to process; surface a readable
        # message instead of an AttributeError traceback.
        raise gr.Error("Please upload an image first.")

    # Both branches assume three colour channels (the normalisation stats and
    # the per-channel FFT), so grayscale / RGBA uploads are converted up front.
    img = img.convert("RGB")

    if mode == "FFT":
        return None, apply_fft_visualization(img)

    img_tensor = transform(img).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        outputs = resnet18(img_tensor)
    predicted = outputs.argmax(dim=1)
    label = classes[predicted.item()]
    return label, None
56
+
57
# Gradio app: wire `predict` to an image upload plus a mode selector, with one
# output slot for the predicted label and one for the FFT plot, then start
# serving.
_input_components = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Radio(["Raw", "FFT"], label="Mode", value="Raw"),
]
_output_components = [
    gr.Label(label="Prediction"),
    gr.Plot(label="FFT Visualization"),
]
_description = "Upload an image and choose mode: Raw classification (ResNet18) or visualize FFT of RGB channels.\n\nDisclaimer: This model is trained on CIFAR-10 and works best on low-res, centered images."

demo = gr.Interface(
    fn=predict,
    inputs=_input_components,
    outputs=_output_components,
    title="CIFAR-10 Visual Analyzer (ResNet18)",
    description=_description,
)
demo.launch()