MetricMogul commited on
Commit
706ae75
·
verified ·
1 Parent(s): 7a1aa5e

Update shape_e_service.py

Browse files
Files changed (1) hide show
  1. shape_e_service.py +17 -7
shape_e_service.py CHANGED
@@ -1,6 +1,7 @@
1
  import gc
2
  import os
3
- import tempfile
 
4
  from typing import Any, Dict, List, Optional, Tuple
5
 
6
  import gradio as gr
@@ -13,6 +14,11 @@ pipe = None
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
15
 
 
 
 
 
 
16
  EXAMPLES = [
17
  "A cute stylized robot with a round head",
18
  "A fantasy treasure chest with gold trim",
@@ -37,13 +43,17 @@ def get_pipeline():
37
  return pipe
38
 
39
 
40
- def save_frames_to_files(frames) -> List[str]:
 
 
 
 
41
  frame_paths = []
42
- for frame in frames:
43
- fd, frame_path = tempfile.mkstemp(suffix=".png", prefix="shape_frame_")
44
- os.close(fd)
45
  frame.save(frame_path)
46
- frame_paths.append(frame_path)
 
47
  return frame_paths
48
 
49
 
@@ -136,7 +146,7 @@ def generate_and_add_asset(
136
  )
137
 
138
  frames = result.images[0]
139
- frame_paths = save_frames_to_files(frames)
140
 
141
  new_asset = make_asset(prompt, frame_paths)
142
  saved_assets = saved_assets + [new_asset]
 
1
  import gc
2
  import os
3
+ import uuid
4
+ from pathlib import Path
5
  from typing import Any, Dict, List, Optional, Tuple
6
 
7
  import gradio as gr
 
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
16
 
17
+ ROOT_DIR = Path(__file__).resolve().parent
18
+ DATA_DIR = ROOT_DIR / "data"
19
+ ASSETS_DIR = DATA_DIR / "assets"
20
+ ASSETS_DIR.mkdir(parents=True, exist_ok=True)
21
+
22
  EXAMPLES = [
23
  "A cute stylized robot with a round head",
24
  "A fantasy treasure chest with gold trim",
 
43
  return pipe
44
 
45
 
46
+ def save_frames_to_files(frames, prompt: str) -> List[str]:
47
+ asset_id = f"asset_{uuid.uuid4().hex[:8]}"
48
+ asset_dir = ASSETS_DIR / asset_id
49
+ asset_dir.mkdir(parents=True, exist_ok=True)
50
+
51
  frame_paths = []
52
+ for i, frame in enumerate(frames):
53
+ frame_path = asset_dir / f"view_{i:03d}.png"
 
54
  frame.save(frame_path)
55
+ frame_paths.append(str(frame_path))
56
+
57
  return frame_paths
58
 
59
 
 
146
  )
147
 
148
  frames = result.images[0]
149
+ frame_paths = save_frames_to_files(frames, prompt)
150
 
151
  new_asset = make_asset(prompt, frame_paths)
152
  saved_assets = saved_assets + [new_asset]