RFTSystems commited on
Commit
ce46e86
·
verified ·
1 Parent(s): d137713

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -36
app.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
4
  import torch.nn.functional as F
5
  from PIL import Image
6
  import gradio as gr
7
- import os # Import os to check for model file
8
 
9
  # === Simple CNN Model Definition ===
10
  class SimpleCNN(nn.Module):
@@ -27,67 +27,66 @@ class SimpleCNN(nn.Module):
27
  model = SimpleCNN()
28
  model_path = 'simple_cnn_dclr_tuned.pth'
29
 
30
- # Check if the model file exists before loading
31
  if os.path.exists(model_path):
32
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
33
- model.eval() # Set model to evaluation mode
34
  print(f"Model loaded successfully from {model_path}")
35
  else:
36
- print(f"Warning: Model file '{model_path}' not found. Please ensure 'train_dclr_model.py' has been run.")
37
- # Optionally, you might want to exit or raise an error if the model is crucial
38
-
39
 
40
  # === CIFAR-10 Class Labels ===
41
- class_labels = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
42
 
43
  # === Image Preprocessing ===
44
  preprocess = transforms.Compose([
45
- transforms.Resize(32), # CIFAR-10 images are 32x32
46
  transforms.ToTensor(),
47
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet stats are common
48
  ])
49
 
50
  # === Inference Function ===
51
  def inference(input_image: Image.Image):
52
- if model.training: # Ensure model is in eval mode
53
  model.eval()
54
-
55
- # Preprocess the image
56
- processed_image = preprocess(input_image)
57
- # Add a batch dimension
58
- processed_image = processed_image.unsqueeze(0)
59
-
60
- # Perform inference
61
  with torch.no_grad():
62
  outputs = model(processed_image)
63
  probabilities = F.softmax(outputs, dim=1)
64
-
65
- # Convert probabilities to a dictionary of class labels and scores
66
- confidences = {class_labels[i]: float(probabilities[0, i]) for i in range(len(class_labels))}
67
  return confidences
68
 
69
- # === Gradio Interface Setup ===
70
- # Example images (replace with actual paths if available, or keep as dummy for now)
71
- # For a Hugging Face Space, you might place example images in an 'examples/' directory.
72
- example_images = [
73
- # os.path.join(os.path.dirname(__file__), "examples/example_car.png"),
74
- # os.path.join(os.path.dirname(__file__), "examples/example_dog.png"),
75
- # os.path.join(os.path.dirname(__file__), "examples/example_plane.png")
76
- ]
77
 
78
- # A placeholder for example images since we don't have them generated yet.
79
- # Users can upload their own or I will add some placeholder images if needed in the next step.
80
- # For now, an empty list of examples is fine.
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  interface = gr.Interface(
83
- fn=inference,
84
- inputs=gr.Image(type='pil', label='Input Image'),
85
- outputs=gr.Label(num_top_classes=3, label='Predictions'),
 
 
 
 
 
86
  title='CIFAR-10 Image Classification with DCLR Optimizer',
87
- description='Upload an image and see the model\'s predictions using a SimpleCNN trained with the DCLR optimizer.',
88
  examples=example_images
89
  )
90
 
91
- # === Launch Gradio App ===
92
  if __name__ == '__main__':
93
  interface.launch()
 
4
  import torch.nn.functional as F
5
  from PIL import Image
6
  import gradio as gr
7
+ import os
8
 
9
  # === Simple CNN Model Definition ===
10
  class SimpleCNN(nn.Module):
 
27
  model = SimpleCNN()
28
  model_path = 'simple_cnn_dclr_tuned.pth'
29
 
 
30
  if os.path.exists(model_path):
31
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
32
+ model.eval()
33
  print(f"Model loaded successfully from {model_path}")
34
  else:
35
+ print(f"Warning: Model file '{model_path}' not found. Please run train_dclr_model.py first.")
 
 
36
 
37
  # === CIFAR-10 Class Labels ===
38
+ class_labels = ['plane','car','bird','cat','deer','dog','frog','horse','ship','truck']
39
 
40
  # === Image Preprocessing ===
41
  preprocess = transforms.Compose([
42
+ transforms.Resize(32),
43
  transforms.ToTensor(),
44
+ transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
45
  ])
46
 
47
  # === Inference Function ===
48
  def inference(input_image: Image.Image):
49
+ if model.training:
50
  model.eval()
51
+ processed_image = preprocess(input_image).unsqueeze(0)
 
 
 
 
 
 
52
  with torch.no_grad():
53
  outputs = model(processed_image)
54
  probabilities = F.softmax(outputs, dim=1)
55
+ confidences = {class_labels[i]: float(probabilities[0,i]) for i in range(len(class_labels))}
 
 
56
  return confidences
57
 
58
+ # === Results Viewer Function ===
59
+ def show_results(input_image: Image.Image):
60
+ preds = inference(input_image)
 
 
 
 
 
61
 
62
+ # Load plots if they exist
63
+ perf_plot = "training_performance.png" if os.path.exists("training_performance.png") else None
64
+ acc_plot = "final_test_accuracy.png" if os.path.exists("final_test_accuracy.png") else None
65
+
66
+ # Load final test accuracy number if available
67
+ test_acc_text = "Final Test Accuracy plot not found."
68
+ if acc_plot and os.path.exists("final_test_accuracy.png"):
69
+ # You can optionally parse accuracy from a saved file; here we just show a placeholder
70
+ test_acc_text = "See bar chart above for final test accuracy."
71
+
72
+ return preds, perf_plot, acc_plot, test_acc_text
73
+
74
+ # === Gradio Interface Setup ===
75
+ example_images = []
76
 
77
  interface = gr.Interface(
78
+ fn=show_results,
79
+ inputs=gr.Image(type='pil', label='Upload Image'),
80
+ outputs=[
81
+ gr.Label(num_top_classes=3, label='Predictions'),
82
+ gr.Image(type='filepath', label='Training Performance'),
83
+ gr.Image(type='filepath', label='Final Test Accuracy Plot'),
84
+ gr.Textbox(label='Final Test Accuracy')
85
+ ],
86
  title='CIFAR-10 Image Classification with DCLR Optimizer',
87
+ description='Upload an image to see predictions. Training/test plots show benchmark results on CIFAR-10.',
88
  examples=example_images
89
  )
90
 
 
91
  if __name__ == '__main__':
92
  interface.launch()