ifire commited on
Commit
fbf185f
·
1 Parent(s): 002a2e8

Refactor predictor to use Path and tempfile

Browse files
Files changed (1) hide show
  1. predict.py +10 -10
predict.py CHANGED
@@ -1,10 +1,11 @@
1
- from cog import BasePredictor, Input, File
2
- import torch
3
  import time
4
  from meshgpt_pytorch import MeshTransformer, mesh_render
5
  import igl
6
  import numpy as np
7
 
 
8
  class Predictor(BasePredictor):
9
  def setup(self):
10
  """Load the model into memory to make running multiple predictions efficient"""
@@ -18,14 +19,15 @@ class Predictor(BasePredictor):
18
  c, _ = igl.orientable_patches(f)
19
  f, _ = igl.orient_outward(v, f, c)
20
  igl.write_triangle_mesh(file_path, v, f)
21
- return file_path
 
22
 
23
  def predict(
24
  self,
25
  text: str = Input(description="Enter labels, separated by commas"),
26
  num_input: int = Input(description="Number of examples per input", default=1),
27
  num_temp: float = Input(description="Temperature (0 to 1)", default=0),
28
- ) -> str:
29
  """Run a single prediction on the model"""
30
  self.transformer.eval()
31
  labels = [label.strip() for label in text.split(",")]
@@ -59,12 +61,10 @@ class Predictor(BasePredictor):
59
  output.append(
60
  (self.transformer.generate(texts=labels, temperature=num_temp))
61
  )
62
-
63
- mesh_render.save_rendering("./render.obj", output)
64
- file_path = self.save_as_obj("./render.obj")
65
-
66
- with open(file_path, 'rb') as file:
67
- return file.read()
68
 
69
 
70
  if __name__ == "__main__":
 
1
+ from cog import BasePredictor, Input, Path
2
+ import tempfile
3
  import time
4
  from meshgpt_pytorch import MeshTransformer, mesh_render
5
  import igl
6
  import numpy as np
7
 
8
+
9
  class Predictor(BasePredictor):
10
  def setup(self):
11
  """Load the model into memory to make running multiple predictions efficient"""
 
19
  c, _ = igl.orientable_patches(f)
20
  f, _ = igl.orient_outward(v, f, c)
21
  igl.write_triangle_mesh(file_path, v, f)
22
+ output_path = Path(tempfile.mkdtemp()) / file_path
23
+ return output_path
24
 
25
  def predict(
26
  self,
27
  text: str = Input(description="Enter labels, separated by commas"),
28
  num_input: int = Input(description="Number of examples per input", default=1),
29
  num_temp: float = Input(description="Temperature (0 to 1)", default=0),
30
+ ) -> Path:
31
  """Run a single prediction on the model"""
32
  self.transformer.eval()
33
  labels = [label.strip() for label in text.split(",")]
 
61
  output.append(
62
  (self.transformer.generate(texts=labels, temperature=num_temp))
63
  )
64
+ file_name = "./mesh.obj"
65
+ mesh_render.save_rendering(file_name, output)
66
+ file_path = self.save_as_obj(file_name)
67
+ return file_path
 
 
68
 
69
 
70
  if __name__ == "__main__":