RFTSystems commited on
Commit
fe05156
·
verified ·
1 Parent(s): ad84fb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -45
app.py CHANGED
@@ -6,6 +6,7 @@ import torch.nn.functional as F
6
  from PIL import Image
7
  import gradio as gr
8
  import os
 
9
 
10
  # === Simple CNN Model Definition ===
11
  class SimpleCNN(nn.Module):
@@ -45,7 +46,11 @@ preprocess = transforms.Compose([
45
  transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
46
  ])
47
 
48
- # === Inference Function ===
 
 
 
 
49
  def inference(input_image: Image.Image):
50
  if model.training:
51
  model.eval()
@@ -56,58 +61,53 @@ def inference(input_image: Image.Image):
56
  confidences = {class_labels[i]: float(probabilities[0,i]) for i in range(len(class_labels))}
57
  return confidences
58
 
59
- # === Results Viewer Function ===
60
- def show_results(input_image: Image.Image):
61
- preds = inference(input_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  # Load plots if they exist
64
  perf_plot = "training_performance.png" if os.path.exists("training_performance.png") else None
65
  acc_plot = "final_test_accuracy.png" if os.path.exists("final_test_accuracy.png") else None
66
 
67
- # Load final test accuracy number
68
- test_acc_text = "Final test accuracy not available."
69
- if os.path.exists("final_test_accuracy.txt"):
70
- with open("final_test_accuracy.txt", "r") as f:
71
- test_acc_value = f.read().strip()
72
- test_acc_text = f"Final Test Accuracy: {test_acc_value}%"
73
 
74
- return preds, perf_plot, acc_plot, test_acc_text
75
-
76
- # === Prepare CIFAR-10 Sample Gallery (one per class) ===
77
- sample_dir = "examples"
78
- os.makedirs(sample_dir, exist_ok=True)
79
 
80
- transform_gallery = transforms.Compose([transforms.ToPILImage()])
81
- test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
 
 
82
 
83
- # Collect one sample for each class
84
- example_images = []
85
- seen_classes = set()
86
- for idx in range(len(test_set)):
87
- img, label = test_set[idx]
88
- if label not in seen_classes:
89
- pil_img = transform_gallery(img)
90
- file_path = os.path.join(sample_dir, f"example_{class_labels[label]}.png")
91
- pil_img.save(file_path)
92
- example_images.append(file_path)
93
- seen_classes.add(label)
94
- if len(seen_classes) == 10:
95
- break
96
 
97
- # === Gradio Interface Setup ===
98
- interface = gr.Interface(
99
- fn=show_results,
100
- inputs=gr.Image(type='pil', label='Upload Image'),
101
- outputs=[
102
- gr.Label(num_top_classes=3, label='Predictions'),
103
- gr.Image(type='filepath', label='Training Performance'),
104
- gr.Image(type='filepath', label='Final Test Accuracy Plot'),
105
- gr.Textbox(label='Final Test Accuracy')
106
- ],
107
- title='CIFAR-10 Image Classification with DCLR Optimizer',
108
- description='Upload an image or try sample CIFAR-10 images. See predictions plus benchmark plots and accuracy.',
109
- examples=example_images
110
- )
111
 
112
  if __name__ == '__main__':
113
- interface.launch()
 
6
  from PIL import Image
7
  import gradio as gr
8
  import os
9
+ import numpy as np
10
 
11
  # === Simple CNN Model Definition ===
12
  class SimpleCNN(nn.Module):
 
46
  transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
47
  ])
48
 
49
+ # === CIFAR-10 Test Loader for Benchmark Mode ===
50
+ test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
51
+ test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
52
+
53
+ # === Inference Function (single image) ===
54
  def inference(input_image: Image.Image):
55
  if model.training:
56
  model.eval()
 
61
  confidences = {class_labels[i]: float(probabilities[0,i]) for i in range(len(class_labels))}
62
  return confidences
63
 
64
+ # === Benchmark Mode: Evaluate on full test set ===
65
+ def benchmark():
66
+ model.eval()
67
+ correct = 0
68
+ total = 0
69
+ class_correct = np.zeros(10)
70
+ class_total = np.zeros(10)
71
+
72
+ with torch.no_grad():
73
+ for inputs, labels in test_loader:
74
+ outputs = model(inputs)
75
+ _, predicted = outputs.max(1)
76
+ total += labels.size(0)
77
+ correct += predicted.eq(labels).sum().item()
78
+ c = (predicted == labels).squeeze()
79
+ for i in range(len(labels)):
80
+ label = labels[i].item()
81
+ class_correct[label] += c[i].item()
82
+ class_total[label] += 1
83
+
84
+ overall_acc = 100.0 * correct / total
85
+ classwise_acc = {class_labels[i]: round(100.0 * class_correct[i] / class_total[i], 2) for i in range(10)}
86
 
87
  # Load plots if they exist
88
  perf_plot = "training_performance.png" if os.path.exists("training_performance.png") else None
89
  acc_plot = "final_test_accuracy.png" if os.path.exists("final_test_accuracy.png") else None
90
 
91
+ return overall_acc, classwise_acc, perf_plot, acc_plot
 
 
 
 
 
92
 
93
+ # === Gradio Interface Setup ===
94
+ with gr.Blocks() as demo:
95
+ gr.Markdown("## CIFAR-10 Image Classification with DCLR Optimizer")
96
+ gr.Markdown("Upload an image for prediction, or run Benchmark Mode to see full test accuracy.")
 
97
 
98
+ with gr.Tab("Single Image Inference"):
99
+ inp = gr.Image(type='pil', label='Upload Image')
100
+ out = gr.Label(num_top_classes=3, label='Predictions')
101
+ inp.change(fn=inference, inputs=inp, outputs=out)
102
 
103
+ with gr.Tab("Benchmark Mode"):
104
+ btn = gr.Button("Run Benchmark on CIFAR-10 Test Set")
105
+ overall = gr.Textbox(label="Overall Test Accuracy")
106
+ classwise = gr.JSON(label="Per-Class Accuracy (%)")
107
+ perf_plot = gr.Image(type='filepath', label='Training Performance')
108
+ acc_plot = gr.Image(type='filepath', label='Final Test Accuracy Plot')
 
 
 
 
 
 
 
109
 
110
+ btn.click(fn=benchmark, inputs=None, outputs=[overall, classwise, perf_plot, acc_plot])
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  if __name__ == '__main__':
113
+ demo.launch()