SanskarModi commited on
Commit
69ec3b8
·
verified ·
1 Parent(s): aec0356

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +7 -8
prediction.py CHANGED
@@ -301,14 +301,14 @@ class Prediction:
301
  input_tensor = torch.stack(frames).unsqueeze(0)
302
  input_tensor = input_tensor.view(1, target_seq_length, 3, *self.resolution)
303
  input_tensor = input_tensor.to(self.device)
304
- input_tensor.requires_grad_()
305
-
306
- # Forward pass
307
- with torch.no_grad():
308
- fmap, attn_wts, logits = self.model(input_tensor)
309
-
 
310
  # Register hook for Grad-CAM
311
- fmap.requires_grad_()
312
  fmap.register_hook(self.save_gradients)
313
 
314
  # Get predictions for all classes
@@ -358,4 +358,3 @@ class Prediction:
358
  gradcam_image = None
359
 
360
  return prediction_string, gradcam_image, classification_details
361
-
 
301
  input_tensor = torch.stack(frames).unsqueeze(0)
302
  input_tensor = input_tensor.view(1, target_seq_length, 3, *self.resolution)
303
  input_tensor = input_tensor.to(self.device)
304
+
305
+ # Remove the torch.no_grad() context to allow gradient computation
306
+ input_tensor.requires_grad_(True)
307
+
308
+ # Forward pass with gradient tracking enabled
309
+ fmap, attn_wts, logits = self.model(input_tensor)
310
+
311
  # Register hook for Grad-CAM
 
312
  fmap.register_hook(self.save_gradients)
313
 
314
  # Get predictions for all classes
 
358
  gradcam_image = None
359
 
360
  return prediction_string, gradcam_image, classification_details