Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -49,6 +49,7 @@ class Preprocessor:
|
|
| 49 |
elif name == "NormalBae":
|
| 50 |
print("Loading NormalBae")
|
| 51 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 52 |
self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to(device)
|
| 53 |
if torch.cuda.is_available():
|
| 54 |
torch.cuda.empty_cache()
|
|
@@ -57,11 +58,35 @@ class Preprocessor:
|
|
| 57 |
raise ValueError
|
| 58 |
return
|
| 59 |
|
|
|
|
| 60 |
def __call__(self, image: Image.Image, **kwargs) -> Image.Image:
|
| 61 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# Load models and preprocessor when the script starts
|
| 67 |
# Controlnet Normal
|
|
|
|
| 49 |
elif name == "NormalBae":
|
| 50 |
print("Loading NormalBae")
|
| 51 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 52 |
+
# The .to(device) call should add the .device attribute
|
| 53 |
self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to(device)
|
| 54 |
if torch.cuda.is_available():
|
| 55 |
torch.cuda.empty_cache()
|
|
|
|
| 58 |
raise ValueError
|
| 59 |
return
|
| 60 |
|
| 61 |
+
# Modified __call__ method to safely check device
|
| 62 |
def __call__(self, image: Image.Image, **kwargs) -> Image.Image:
|
| 63 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 64 |
+
|
| 65 |
+
# Check if the model has the 'device' attribute before accessing its 'type'
|
| 66 |
+
if hasattr(self.model, 'device'):
|
| 67 |
+
if self.model.device.type != device:
|
| 68 |
+
print(f"Moving preprocessor model to {device}") # Debug print
|
| 69 |
+
# Attempt to move the model if it's not on the correct device
|
| 70 |
+
try:
|
| 71 |
+
self.model.to(device)
|
| 72 |
+
except Exception as e:
|
| 73 |
+
print(f"Error moving preprocessor model to {device}: {e}")
|
| 74 |
+
# If moving fails, the next call might fail with a device mismatch
|
| 75 |
+
pass # Continue and let the next line potentially raise an error
|
| 76 |
+
else:
|
| 77 |
+
# If .device attribute is missing, assume it's not on the correct device and try to move it
|
| 78 |
+
# This might happen if NormalBaeDetector is not a standard torch.nn.Module
|
| 79 |
+
print("Warning: Preprocessor model has no .device attribute. Attempting to move to correct device.")
|
| 80 |
+
try:
|
| 81 |
+
self.model.to(device) # Attempt to move even if .device is missing
|
| 82 |
+
except Exception as e:
|
| 83 |
+
print(f"Error attempting to move preprocessor model without .device attribute: {e}")
|
| 84 |
+
# If moving fails, the next call will likely fail with a device mismatch
|
| 85 |
+
pass # Continue and let the next line potentially raise an error
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
return self.model(image, **kwargs) # This call might fail if the model isn't on the correct device
|
| 89 |
+
|
| 90 |
|
| 91 |
# Load models and preprocessor when the script starts
|
| 92 |
# Controlnet Normal
|