Jagjeet2003 commited on
Commit
1cab240
·
verified ·
1 Parent(s): 37abcf8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -1,30 +1,32 @@
1
  import torch
2
  import gradio as gr
3
- from model import ALexNet # Make sure this matches your actual class name
4
  from torchvision import transforms
5
  from PIL import Image
6
 
7
- # Load model
8
- model = ALexNet(3, 64, 10)
9
- model.load_state_dict(torch.load("Modified_ALexnet_for_CIFAR.pth", map_location=torch.device("cpu")))
10
- model.eval()
 
 
 
 
 
 
 
11
 
12
- # Preprocessing
13
  transform = transforms.Compose([
14
- transforms.Resize((32, 32)), # Adjust to your model's input size
15
  transforms.ToTensor()
16
  ])
17
 
18
- # Inference function
19
  def predict(img):
20
  img = transform(img).unsqueeze(0)
21
  with torch.no_grad():
22
  outputs = model(img)
23
  predicted_class = torch.argmax(outputs, dim=1).item()
24
  class_names = ["airplane", "automobile", "bird", "cat", "deer",
25
- "dog", "frog", "horse", "ship", "truck"]
26
-
27
  return f"Predicted class: {class_names[predicted_class]}"
28
 
29
- # Gradio UI
30
  gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text").launch()
 
1
  import torch
2
  import gradio as gr
 
3
  from torchvision import transforms
4
  from PIL import Image
5
 
6
+ from model import ALexNet # Make sure this file and class exist
7
+
8
+ print("App is starting...")
9
+
10
+ try:
11
+ model = ALexNet(3, 64, 10)
12
+ model.load_state_dict(torch.load("Modified_ALexnet_for_CIFAR.pth", map_location=torch.device("cpu")))
13
+ model.eval()
14
+ print("Model loaded successfully.")
15
+ except Exception as e:
16
+ print(f"Failed to load model: {e}")
17
 
 
18
  transform = transforms.Compose([
19
+ transforms.Resize((32, 32)),
20
  transforms.ToTensor()
21
  ])
22
 
 
23
  def predict(img):
24
  img = transform(img).unsqueeze(0)
25
  with torch.no_grad():
26
  outputs = model(img)
27
  predicted_class = torch.argmax(outputs, dim=1).item()
28
  class_names = ["airplane", "automobile", "bird", "cat", "deer",
29
+ "dog", "frog", "horse", "ship", "truck"]
 
30
  return f"Predicted class: {class_names[predicted_class]}"
31
 
 
32
  gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text").launch()