Spaces:
Sleeping
Sleeping
Update prediction.py
Browse files- prediction.py +7 -8
prediction.py
CHANGED
|
@@ -301,14 +301,14 @@ class Prediction:
|
|
| 301 |
input_tensor = torch.stack(frames).unsqueeze(0)
|
| 302 |
input_tensor = input_tensor.view(1, target_seq_length, 3, *self.resolution)
|
| 303 |
input_tensor = input_tensor.to(self.device)
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
|
|
|
| 310 |
# Register hook for Grad-CAM
|
| 311 |
-
fmap.requires_grad_()
|
| 312 |
fmap.register_hook(self.save_gradients)
|
| 313 |
|
| 314 |
# Get predictions for all classes
|
|
@@ -358,4 +358,3 @@ class Prediction:
|
|
| 358 |
gradcam_image = None
|
| 359 |
|
| 360 |
return prediction_string, gradcam_image, classification_details
|
| 361 |
-
|
|
|
|
| 301 |
input_tensor = torch.stack(frames).unsqueeze(0)
|
| 302 |
input_tensor = input_tensor.view(1, target_seq_length, 3, *self.resolution)
|
| 303 |
input_tensor = input_tensor.to(self.device)
|
| 304 |
+
|
| 305 |
+
# Remove the torch.no_grad() context to allow gradient computation
|
| 306 |
+
input_tensor.requires_grad_(True)
|
| 307 |
+
|
| 308 |
+
# Forward pass with gradient tracking enabled
|
| 309 |
+
fmap, attn_wts, logits = self.model(input_tensor)
|
| 310 |
+
|
| 311 |
# Register hook for Grad-CAM
|
|
|
|
| 312 |
fmap.register_hook(self.save_gradients)
|
| 313 |
|
| 314 |
# Get predictions for all classes
|
|
|
|
| 358 |
gradcam_image = None
|
| 359 |
|
| 360 |
return prediction_string, gradcam_image, classification_details
|
|
|