Spaces:
Paused
Paused
Ali Mohsin
commited on
Commit
·
c0eeb7b
1
Parent(s):
8d2d202
hdhfds 44
Browse files- inference.py +12 -12
inference.py
CHANGED
|
@@ -45,8 +45,8 @@ class InferenceService:
|
|
| 45 |
# Disable gradients
|
| 46 |
for m in [self.resnet, self.vit]:
|
| 47 |
if m is not None:
|
| 48 |
-
|
| 49 |
-
|
| 50 |
|
| 51 |
# Update overall status
|
| 52 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
|
@@ -177,8 +177,8 @@ class InferenceService:
|
|
| 177 |
# Disable gradients
|
| 178 |
for m in [self.resnet, self.vit]:
|
| 179 |
if m is not None:
|
| 180 |
-
|
| 181 |
-
|
| 182 |
|
| 183 |
# Update overall status
|
| 184 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
|
@@ -202,12 +202,12 @@ class InferenceService:
|
|
| 202 |
return []
|
| 203 |
|
| 204 |
try:
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
result = [e.detach().cpu().numpy().astype(np.float32) for e in emb]
|
| 212 |
print(f"🔍 DEBUG: Successfully generated {len(result)} embeddings")
|
| 213 |
return result
|
|
@@ -658,8 +658,8 @@ class InferenceService:
|
|
| 658 |
# Disable gradients
|
| 659 |
for m in [self.resnet, self.vit]:
|
| 660 |
if m is not None:
|
| 661 |
-
|
| 662 |
-
|
| 663 |
|
| 664 |
# Update overall status
|
| 665 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
|
|
|
| 45 |
# Disable gradients
|
| 46 |
for m in [self.resnet, self.vit]:
|
| 47 |
if m is not None:
|
| 48 |
+
for p in m.parameters():
|
| 49 |
+
p.requires_grad_(False)
|
| 50 |
|
| 51 |
# Update overall status
|
| 52 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
|
|
|
| 177 |
# Disable gradients
|
| 178 |
for m in [self.resnet, self.vit]:
|
| 179 |
if m is not None:
|
| 180 |
+
for p in m.parameters():
|
| 181 |
+
p.requires_grad_(False)
|
| 182 |
|
| 183 |
# Update overall status
|
| 184 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|
|
|
|
| 202 |
return []
|
| 203 |
|
| 204 |
try:
|
| 205 |
+
batch = torch.stack([self.transform(img) for img in images])
|
| 206 |
+
batch = batch.to(self.device, memory_format=torch.channels_last)
|
| 207 |
+
use_amp = (self.device == "cuda")
|
| 208 |
+
with torch.autocast(device_type=("cuda" if use_amp else "cpu"), enabled=use_amp):
|
| 209 |
+
emb = self.resnet(batch)
|
| 210 |
+
emb = nn.functional.normalize(emb, dim=-1)
|
| 211 |
result = [e.detach().cpu().numpy().astype(np.float32) for e in emb]
|
| 212 |
print(f"🔍 DEBUG: Successfully generated {len(result)} embeddings")
|
| 213 |
return result
|
|
|
|
| 658 |
# Disable gradients
|
| 659 |
for m in [self.resnet, self.vit]:
|
| 660 |
if m is not None:
|
| 661 |
+
for p in m.parameters():
|
| 662 |
+
p.requires_grad_(False)
|
| 663 |
|
| 664 |
# Update overall status
|
| 665 |
self.models_loaded = self.resnet_loaded and self.vit_loaded
|