RFTSystems commited on
Commit
64d8db4
·
verified ·
1 Parent(s): 58e8281

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -2
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  import torchvision.transforms as transforms
 
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  from PIL import Image
@@ -72,9 +73,26 @@ def show_results(input_image: Image.Image):
72
 
73
  return preds, perf_plot, acc_plot, test_acc_text
74
 
75
- # === Gradio Interface Setup ===
 
 
 
 
 
 
 
 
 
 
76
  example_images = []
 
 
 
 
 
 
77
 
 
78
  interface = gr.Interface(
79
  fn=show_results,
80
  inputs=gr.Image(type='pil', label='Upload Image'),
@@ -85,7 +103,7 @@ interface = gr.Interface(
85
  gr.Textbox(label='Final Test Accuracy')
86
  ],
87
  title='CIFAR-10 Image Classification with DCLR Optimizer',
88
- description='Upload an image to see predictions. Training/test plots and accuracy show benchmark results on CIFAR-10.',
89
  examples=example_images
90
  )
91
 
 
1
  import torch
2
  import torchvision.transforms as transforms
3
+ import torchvision
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from PIL import Image
 
73
 
74
  return preds, perf_plot, acc_plot, test_acc_text
75
 
76
+ # === Prepare CIFAR-10 Sample Gallery ===
77
+ # Download CIFAR-10 test set and save a few sample images
78
+ sample_dir = "examples"
79
+ os.makedirs(sample_dir, exist_ok=True)
80
+
81
+ transform_gallery = transforms.Compose([transforms.ToPILImage()])
82
+
83
+ test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
84
+
85
+ # Pick a few samples (car, dog, plane, cat)
86
+ sample_indices = [1, 3, 10, 25] # arbitrary indices
87
  example_images = []
88
+ for idx in sample_indices:
89
+ img, label = test_set[idx]
90
+ pil_img = transform_gallery(img)
91
+ file_path = os.path.join(sample_dir, f"example_{class_labels[label]}.png")
92
+ pil_img.save(file_path)
93
+ example_images.append(file_path)
94
 
95
+ # === Gradio Interface Setup ===
96
  interface = gr.Interface(
97
  fn=show_results,
98
  inputs=gr.Image(type='pil', label='Upload Image'),
 
103
  gr.Textbox(label='Final Test Accuracy')
104
  ],
105
  title='CIFAR-10 Image Classification with DCLR Optimizer',
106
+ description='Upload an image or try sample CIFAR-10 images. See predictions plus benchmark plots and accuracy.',
107
  examples=example_images
108
  )
109