Update app.py
Browse files
app.py
CHANGED
|
@@ -48,9 +48,14 @@ class Model:
|
|
| 48 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 49 |
self.pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)
|
| 50 |
self.pipe.to(self.device)
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
|
| 53 |
self.pipe_img.to(self.device)
|
|
|
|
|
|
|
| 54 |
|
| 55 |
def to_glb(self, ply_path: str) -> str:
|
| 56 |
mesh = trimesh.load(ply_path)
|
|
@@ -288,7 +293,7 @@ def generate(
|
|
| 288 |
):
|
| 289 |
"""
|
| 290 |
Generates chatbot responses with support for multimodal input, TTS, image generation,
|
| 291 |
-
and
|
| 292 |
|
| 293 |
Special commands:
|
| 294 |
- "@tts1" or "@tts2": triggers text-to-speech.
|
|
|
|
| 48 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 49 |
self.pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)
|
| 50 |
self.pipe.to(self.device)
|
| 51 |
+
# On CUDA, explicitly cast the text encoder to half precision so its dtype matches the float16 pipeline and avoids dtype-mismatch errors.
|
| 52 |
+
if torch.cuda.is_available():
|
| 53 |
+
self.pipe.text_encoder = self.pipe.text_encoder.half()
|
| 54 |
|
| 55 |
self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
|
| 56 |
self.pipe_img.to(self.device)
|
| 57 |
+
if torch.cuda.is_available():
|
| 58 |
+
self.pipe_img.text_encoder = self.pipe_img.text_encoder.half()
|
| 59 |
|
| 60 |
def to_glb(self, ply_path: str) -> str:
|
| 61 |
mesh = trimesh.load(ply_path)
|
|
|
|
| 293 |
):
|
| 294 |
"""
|
| 295 |
Generates chatbot responses with support for multimodal input, TTS, image generation,
|
| 296 |
+
and 3D model generation.
|
| 297 |
|
| 298 |
Special commands:
|
| 299 |
- "@tts1" or "@tts2": triggers text-to-speech.
|