RFTSystems commited on
Commit
ad84fb8
·
verified ·
1 Parent(s): 64d8db4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -73,24 +73,26 @@ def show_results(input_image: Image.Image):
73
 
74
  return preds, perf_plot, acc_plot, test_acc_text
75
 
76
- # === Prepare CIFAR-10 Sample Gallery ===
77
- # Download CIFAR-10 test set and save a few sample images
78
  sample_dir = "examples"
79
  os.makedirs(sample_dir, exist_ok=True)
80
 
81
  transform_gallery = transforms.Compose([transforms.ToPILImage()])
82
-
83
  test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
84
 
85
- # Pick a few samples (car, dog, plane, cat)
86
- sample_indices = [1, 3, 10, 25] # arbitrary indices
87
  example_images = []
88
- for idx in sample_indices:
 
89
  img, label = test_set[idx]
90
- pil_img = transform_gallery(img)
91
- file_path = os.path.join(sample_dir, f"example_{class_labels[label]}.png")
92
- pil_img.save(file_path)
93
- example_images.append(file_path)
 
 
 
 
94
 
95
  # === Gradio Interface Setup ===
96
  interface = gr.Interface(
 
73
 
74
  return preds, perf_plot, acc_plot, test_acc_text
75
 
76
+ # === Prepare CIFAR-10 Sample Gallery (one per class) ===
 
77
  sample_dir = "examples"
78
  os.makedirs(sample_dir, exist_ok=True)
79
 
80
  transform_gallery = transforms.Compose([transforms.ToPILImage()])
 
81
  test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
82
 
83
+ # Collect one sample for each class
 
84
  example_images = []
85
+ seen_classes = set()
86
+ for idx in range(len(test_set)):
87
  img, label = test_set[idx]
88
+ if label not in seen_classes:
89
+ pil_img = transform_gallery(img)
90
+ file_path = os.path.join(sample_dir, f"example_{class_labels[label]}.png")
91
+ pil_img.save(file_path)
92
+ example_images.append(file_path)
93
+ seen_classes.add(label)
94
+ if len(seen_classes) == 10:
95
+ break
96
 
97
  # === Gradio Interface Setup ===
98
  interface = gr.Interface(