Update app.py
Browse files
app.py
CHANGED
|
@@ -50,15 +50,22 @@ class Model:
|
|
| 50 |
self.pipe.to(self.device)
|
| 51 |
# Ensure the text encoder is in half precision to avoid dtype mismatches.
|
| 52 |
if torch.cuda.is_available():
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
|
| 56 |
self.pipe_img.to(self.device)
|
|
|
|
| 57 |
if torch.cuda.is_available():
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
|
| 60 |
def to_glb(self, ply_path: str) -> str:
|
| 61 |
mesh = trimesh.load(ply_path)
|
|
|
|
| 62 |
rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
|
| 63 |
mesh.apply_transform(rot)
|
| 64 |
rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
|
|
@@ -447,4 +454,4 @@ demo = gr.ChatInterface(
|
|
| 447 |
|
| 448 |
if __name__ == "__main__":
|
| 449 |
# To create a public link, set share=True in launch().
|
| 450 |
-
demo.queue(max_size=20).launch(share=True)
|
|
|
|
| 50 |
self.pipe.to(self.device)
|
| 51 |
# Ensure the text encoder is in half precision to avoid dtype mismatches.
|
| 52 |
if torch.cuda.is_available():
|
| 53 |
+
try:
|
| 54 |
+
self.pipe.text_encoder = self.pipe.text_encoder.half()
|
| 55 |
+
except AttributeError:
|
| 56 |
+
pass
|
| 57 |
|
| 58 |
self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
|
| 59 |
self.pipe_img.to(self.device)
|
| 60 |
+
# Use getattr with a default value to avoid AttributeError if text_encoder is missing.
|
| 61 |
if torch.cuda.is_available():
|
| 62 |
+
text_encoder_img = getattr(self.pipe_img, "text_encoder", None)
|
| 63 |
+
if text_encoder_img is not None:
|
| 64 |
+
self.pipe_img.text_encoder = text_encoder_img.half()
|
| 65 |
|
| 66 |
def to_glb(self, ply_path: str) -> str:
|
| 67 |
mesh = trimesh.load(ply_path)
|
| 68 |
+
# Rotate the mesh for proper orientation
|
| 69 |
rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
|
| 70 |
mesh.apply_transform(rot)
|
| 71 |
rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
|
|
|
|
| 454 |
|
| 455 |
if __name__ == "__main__":
|
| 456 |
# To create a public link, set share=True in launch().
|
| 457 |
+
demo.queue(max_size=20).launch(share=True)
|