amar6de2 commited on
Commit
7a7cf02
·
1 Parent(s): 5f0b74a
Files changed (1) hide show
  1. app.py +27 -18
app.py CHANGED
@@ -42,30 +42,39 @@ vit.load_state_dict(torch.load("vit_epoch_2.pth", map_location=torch.device("cpu
42
 
43
  def predict(img) -> Tuple[Dict[str, float], float]:
44
  """Transforms and performs a prediction on img and returns prediction and time taken."""
45
- # Ensure the image is a PIL image
46
- if isinstance(img, np.ndarray):
47
- img = Image.fromarray(img)
48
 
49
- # Start the timer
50
- start_time = timer()
 
 
51
 
52
- # Transform the image and add batch dimension
53
- img = vit_transforms(img).unsqueeze(0)
 
54
 
55
- # Run inference
56
- vit.eval()
57
- with torch.inference_mode():
58
- pred_probs = torch.softmax(vit(img), dim=1)
59
 
60
- # Create label and probability dict
61
- pred_labels_and_probs = {
62
- class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))
63
- }
64
 
65
- # Calculate prediction time
66
- pred_time = round(timer() - start_time, 5)
 
 
67
 
68
- return pred_labels_and_probs, pred_time
 
 
 
 
 
 
 
 
 
 
69
 
70
  ### 4. Gradio app setup ###
71
 
 
42
 
43
  def predict(img) -> Tuple[Dict[str, float], float]:
44
  """Transforms and performs a prediction on img and returns prediction and time taken."""
45
+ from PIL import UnidentifiedImageError
 
 
46
 
47
+ try:
48
+ # Convert ndarray to PIL if needed
49
+ if isinstance(img, np.ndarray):
50
+ img = Image.fromarray(img)
51
 
52
+ # Catch bad image input
53
+ if img.mode != "RGB":
54
+ img = img.convert("RGB")
55
 
56
+ # Start timer
57
+ start_time = timer()
 
 
58
 
59
+ # Transform and add batch dimension
60
+ img_tensor = vit_transforms(img).unsqueeze(0)
 
 
61
 
62
+ # Inference
63
+ vit.eval()
64
+ with torch.inference_mode():
65
+ pred_probs = torch.softmax(vit(img_tensor), dim=1)
66
 
67
+ pred_labels_and_probs = {
68
+ class_names[i]: float(pred_probs[0][i])
69
+ for i in range(len(class_names))
70
+ }
71
+
72
+ pred_time = round(timer() - start_time, 5)
73
+
74
+ return pred_labels_and_probs, pred_time
75
+
76
+ except (UnidentifiedImageError, TypeError, ValueError) as e:
77
+ return {"Error": f"Invalid image input: {str(e)}"}, 0.0
78
 
79
  ### 4. Gradio app setup ###
80