oscarw-t commited on
Commit
64bc44f
·
1 Parent(s): fe0febc

fixed def predict

Browse files
Files changed (1) hide show
  1. app.py +5 -11
app.py CHANGED
@@ -46,21 +46,15 @@ transform = transforms.Compose([
46
 
47
  # --- Prediction function ---
48
  def predict(image):
49
- """
50
- Takes any image (JPG, PNG, etc.), converts to RGB, resizes to 32x32,
51
- runs through the CNN, and returns class probabilities.
52
- """
53
- # Convert to RGB (in case of grayscale or RGBA input)
54
  image = image.convert("RGB")
55
- image = transform(image).unsqueeze(0) # shape: [1, 3, 32, 32]
56
-
57
  with torch.no_grad():
58
- outputs=gr.Label(num_top_classes=3)
59
- probs = torch.softmax(outputs, dim=1)[0]
60
-
61
- # Convert to dictionary: {class: probability}
62
  return {classes[i]: float(probs[i]) for i in range(10)}
63
 
 
64
  # --- Gradio Interface ---
65
  demo = gr.Interface(
66
  fn=predict,
 
46
 
47
  # --- Prediction function ---
48
  def predict(image):
 
 
 
 
 
49
  image = image.convert("RGB")
50
+ x = transform(image).unsqueeze(0) # (1, 3, 32, 32)
 
51
  with torch.no_grad():
52
+ outputs = model(x) # tensor shape [1, 10]
53
+ probs = torch.nn.functional.softmax(outputs, dim=1) # apply softmax
54
+ probs = probs[0].cpu().numpy() # convert to numpy for Gradio
 
55
  return {classes[i]: float(probs[i]) for i in range(10)}
56
 
57
+
58
  # --- Gradio Interface ---
59
  demo = gr.Interface(
60
  fn=predict,