Spaces:
Running
on
Zero
Running
on
Zero
update device
Browse files
app.py
CHANGED
|
@@ -41,10 +41,11 @@ def load_model(ckpt_path):
|
|
| 41 |
return model
|
| 42 |
|
| 43 |
def run_model(model, img: Image.Image):
|
|
|
|
| 44 |
to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
|
| 45 |
-
image = to_tensor(img).to(
|
| 46 |
x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)
|
| 47 |
-
with torch.no_grad() as no_grad, torch.autocast(device_type=
|
| 48 |
output = model(x)
|
| 49 |
return output
|
| 50 |
|
|
|
|
| 41 |
return model
|
| 42 |
|
| 43 |
def run_model(model, img: Image.Image):
|
| 44 |
+
device = next(model.parameters()).device
|
| 45 |
to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
|
| 46 |
+
image = to_tensor(img).to(device)
|
| 47 |
x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)
|
| 48 |
+
with torch.no_grad() as no_grad, torch.autocast(device_type=device.type) as amp:
|
| 49 |
output = model(x)
|
| 50 |
return output
|
| 51 |
|