Refactor predictor to use Path and tempfile
Browse files- predict.py +10 -10
predict.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
-
from cog import BasePredictor, Input,
|
| 2 |
-
import
|
| 3 |
import time
|
| 4 |
from meshgpt_pytorch import MeshTransformer, mesh_render
|
| 5 |
import igl
|
| 6 |
import numpy as np
|
| 7 |
|
|
|
|
| 8 |
class Predictor(BasePredictor):
|
| 9 |
def setup(self):
|
| 10 |
"""Load the model into memory to make running multiple predictions efficient"""
|
|
@@ -18,14 +19,15 @@ class Predictor(BasePredictor):
|
|
| 18 |
c, _ = igl.orientable_patches(f)
|
| 19 |
f, _ = igl.orient_outward(v, f, c)
|
| 20 |
igl.write_triangle_mesh(file_path, v, f)
|
| 21 |
-
|
|
|
|
| 22 |
|
| 23 |
def predict(
|
| 24 |
self,
|
| 25 |
text: str = Input(description="Enter labels, separated by commas"),
|
| 26 |
num_input: int = Input(description="Number of examples per input", default=1),
|
| 27 |
num_temp: float = Input(description="Temperature (0 to 1)", default=0),
|
| 28 |
-
) ->
|
| 29 |
"""Run a single prediction on the model"""
|
| 30 |
self.transformer.eval()
|
| 31 |
labels = [label.strip() for label in text.split(",")]
|
|
@@ -59,12 +61,10 @@ class Predictor(BasePredictor):
|
|
| 59 |
output.append(
|
| 60 |
(self.transformer.generate(texts=labels, temperature=num_temp))
|
| 61 |
)
|
| 62 |
-
|
| 63 |
-
mesh_render.save_rendering(
|
| 64 |
-
file_path = self.save_as_obj(
|
| 65 |
-
|
| 66 |
-
with open(file_path, 'rb') as file:
|
| 67 |
-
return file.read()
|
| 68 |
|
| 69 |
|
| 70 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
from cog import BasePredictor, Input, Path
|
| 2 |
+
import tempfile
|
| 3 |
import time
|
| 4 |
from meshgpt_pytorch import MeshTransformer, mesh_render
|
| 5 |
import igl
|
| 6 |
import numpy as np
|
| 7 |
|
| 8 |
+
|
| 9 |
class Predictor(BasePredictor):
|
| 10 |
def setup(self):
|
| 11 |
"""Load the model into memory to make running multiple predictions efficient"""
|
|
|
|
| 19 |
c, _ = igl.orientable_patches(f)
|
| 20 |
f, _ = igl.orient_outward(v, f, c)
|
| 21 |
igl.write_triangle_mesh(file_path, v, f)
|
| 22 |
+
output_path = Path(tempfile.mkdtemp()) / file_path
|
| 23 |
+
return output_path
|
| 24 |
|
| 25 |
def predict(
|
| 26 |
self,
|
| 27 |
text: str = Input(description="Enter labels, separated by commas"),
|
| 28 |
num_input: int = Input(description="Number of examples per input", default=1),
|
| 29 |
num_temp: float = Input(description="Temperature (0 to 1)", default=0),
|
| 30 |
+
) -> Path:
|
| 31 |
"""Run a single prediction on the model"""
|
| 32 |
self.transformer.eval()
|
| 33 |
labels = [label.strip() for label in text.split(",")]
|
|
|
|
| 61 |
output.append(
|
| 62 |
(self.transformer.generate(texts=labels, temperature=num_temp))
|
| 63 |
)
|
| 64 |
+
file_name = "./mesh.obj"
|
| 65 |
+
mesh_render.save_rendering(file_name, output)
|
| 66 |
+
file_path = self.save_as_obj(file_name)
|
| 67 |
+
return file_path
|
|
|
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
if __name__ == "__main__":
|