SanskarModi commited on
Commit
aec0356
·
verified ·
1 Parent(s): 397b9f3

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +2 -7
prediction.py CHANGED
@@ -303,14 +303,9 @@ class Prediction:
303
  input_tensor = input_tensor.to(self.device)
304
  input_tensor.requires_grad_()
305
 
306
- # Forward pass to get model output dict
307
  with torch.no_grad():
308
- output_dict = self.model(input_tensor)
309
-
310
- # Extract relevant outputs
311
- fmap = output_dict["fmap"]
312
- attn_wts = output_dict["attn"]
313
- logits = output_dict["logits"]
314
 
315
  # Register hook for Grad-CAM
316
  fmap.requires_grad_()
 
303
  input_tensor = input_tensor.to(self.device)
304
  input_tensor.requires_grad_()
305
 
306
+ # Forward pass
307
  with torch.no_grad():
308
+ fmap, attn_wts, logits = self.model(input_tensor)
 
 
 
 
 
309
 
310
  # Register hook for Grad-CAM
311
  fmap.requires_grad_()