Varhal commited on
Commit
a9d2bac
·
verified ·
1 Parent(s): bfd83e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -3
app.py CHANGED
@@ -49,6 +49,7 @@ class Preprocessor:
49
  elif name == "NormalBae":
50
  print("Loading NormalBae")
51
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
52
  self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to(device)
53
  if torch.cuda.is_available():
54
  torch.cuda.empty_cache()
@@ -57,11 +58,35 @@ class Preprocessor:
57
  raise ValueError
58
  return
59
 
 
60
  def __call__(self, image: Image.Image, **kwargs) -> Image.Image:
61
  device = "cuda" if torch.cuda.is_available() else "cpu"
62
- if self.model.device.type != device:
63
- self.model.to(device)
64
- return self.model(image, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Load models and preprocessor when the script starts
67
  # Controlnet Normal
 
49
  elif name == "NormalBae":
50
  print("Loading NormalBae")
51
  device = "cuda" if torch.cuda.is_available() else "cpu"
52
+ # The .to(device) call should add the .device attribute
53
  self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to(device)
54
  if torch.cuda.is_available():
55
  torch.cuda.empty_cache()
 
58
  raise ValueError
59
  return
60
 
61
+ # Modified __call__ method to safely check device
62
  def __call__(self, image: Image.Image, **kwargs) -> Image.Image:
63
  device = "cuda" if torch.cuda.is_available() else "cpu"
64
+
65
+ # Check if the model has the 'device' attribute before accessing its 'type'
66
+ if hasattr(self.model, 'device'):
67
+ if self.model.device.type != device:
68
+ print(f"Moving preprocessor model to {device}") # Debug print
69
+ # Attempt to move the model if it's not on the correct device
70
+ try:
71
+ self.model.to(device)
72
+ except Exception as e:
73
+ print(f"Error moving preprocessor model to {device}: {e}")
74
+ # If moving fails, the next call might fail with a device mismatch
75
+ pass # Continue and let the next line potentially raise an error
76
+ else:
77
+ # If .device attribute is missing, assume it's not on the correct device and try to move it
78
+ # This might happen if NormalBaeDetector is not a standard torch.nn.Module
79
+ print("Warning: Preprocessor model has no .device attribute. Attempting to move to correct device.")
80
+ try:
81
+ self.model.to(device) # Attempt to move even if .device is missing
82
+ except Exception as e:
83
+ print(f"Error attempting to move preprocessor model without .device attribute: {e}")
84
+ # If moving fails, the next call will likely fail with a device mismatch
85
+ pass # Continue and let the next line potentially raise an error
86
+
87
+
88
+ return self.model(image, **kwargs) # This call might fail if the model isn't on the correct device
89
+
90
 
91
  # Load models and preprocessor when the script starts
92
  # Controlnet Normal