samwell commited on
Commit
dd3f963
·
verified ·
1 Parent(s): 01565f9

FastAPI handler for custom container

Browse files
Files changed (1) hide show
  1. handler.py +87 -53
handler.py CHANGED
@@ -1,71 +1,105 @@
1
  import torch
2
  import base64
3
  import io
4
- from typing import Dict, Any
 
5
  from PIL import Image
 
 
6
 
 
7
 
8
- class EndpointHandler:
9
- def __init__(self, path: str = ""):
10
- from diffusers import Cosmos2VideoToWorldPipeline
11
- from diffusers.utils import export_to_video
12
 
13
- self.export_to_video = export_to_video
14
- model_id = "nvidia/Cosmos-Predict2-2B-Video2World"
 
 
 
 
 
 
15
 
16
- self.pipe = Cosmos2VideoToWorldPipeline.from_pretrained(
17
- model_id,
18
- torch_dtype=torch.bfloat16,
19
- )
20
- self.pipe.to("cuda")
21
-
22
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
23
- inputs = data.get("inputs", data)
24
 
25
- image_data = inputs.get("image")
26
- if not image_data:
27
- return {"error": "No image provided"}
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- try:
30
- if image_data.startswith("http"):
31
- from diffusers.utils import load_image
32
- image = load_image(image_data)
33
- else:
34
- image_bytes = base64.b64decode(image_data)
35
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
36
- except Exception as e:
37
- return {"error": f"Failed to load image: {str(e)}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- prompt = inputs.get("prompt", "")
40
- if not prompt:
41
- return {"error": "No prompt provided"}
 
 
42
 
43
- negative_prompt = inputs.get("negative_prompt", "ugly, static, blurry, low quality")
44
- num_frames = inputs.get("num_frames", 93)
45
- num_inference_steps = inputs.get("num_inference_steps", 35)
46
- guidance_scale = inputs.get("guidance_scale", 7.0)
47
- seed = inputs.get("seed")
48
 
49
- generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed else None
 
 
 
 
 
 
 
 
 
50
 
51
- try:
52
- output = self.pipe(
53
- image=image,
54
- prompt=prompt,
55
- negative_prompt=negative_prompt,
56
- num_frames=num_frames,
57
- num_inference_steps=num_inference_steps,
58
- guidance_scale=guidance_scale,
59
- generator=generator,
60
- )
61
 
62
- video_path = "/tmp/output.mp4"
63
- self.export_to_video(output.frames[0], video_path, fps=16)
64
 
65
- with open(video_path, "rb") as f:
66
- video_b64 = base64.b64encode(f.read()).decode("utf-8")
67
 
68
- return {"video": video_b64, "content_type": "video/mp4"}
 
69
 
70
- except Exception as e:
71
- return {"error": f"Inference failed: {str(e)}"}
 
 
1
  import torch
2
  import base64
3
  import io
4
+ import os
5
+ from typing import Dict, Any, Optional
6
  from PIL import Image
7
+ from fastapi import FastAPI, HTTPException
8
+ from pydantic import BaseModel
9
 
10
app = FastAPI()

# Global pipeline and video-export helper. Both stay None until the
# startup hook (load_model) populates them; the "/" endpoint assumes
# startup has completed before any request arrives.
pipe = None
export_to_video = None
 
15
 
16
class InferenceRequest(BaseModel):
    """Schema for a single video-generation request.

    NOTE(review): this model is declared but the "/" endpoint currently
    accepts a plain ``dict`` rather than validating against it — confirm
    whether it should be wired into the route signature.
    """

    image: str  # base64-encoded image bytes, or an http(s) URL
    prompt: str  # text description of the video to generate
    negative_prompt: str = "ugly, static, blurry, low quality"
    num_frames: int = 93  # default matches the pipeline's expected clip length
    num_inference_steps: int = 35
    guidance_scale: float = 7.0
    seed: Optional[int] = None  # None -> non-deterministic generation
24
 
25
class InferenceInputs(BaseModel):
    """Wrapper matching the ``{"inputs": {...}}`` payload convention.

    NOTE(review): currently unused by the handlers — the "/" endpoint
    unwraps the nested ``inputs`` key manually. Confirm intent.
    """

    inputs: InferenceRequest
 
 
 
 
 
 
27
 
28
@app.on_event("startup")
async def load_model():
    """Load the Cosmos video-to-world pipeline onto the GPU once, at startup.

    Populates the module-level ``pipe`` and ``export_to_video`` globals that
    the "/" endpoint uses. Imports diffusers lazily so the module itself can
    be imported without the heavy dependency chain.
    """
    global pipe, export_to_video

    from diffusers import Cosmos2VideoToWorldPipeline
    from diffusers.utils import export_to_video as _export_fn

    export_to_video = _export_fn

    # Gated model: token is read from the environment (None is fine for
    # already-cached or public weights).
    pipe = Cosmos2VideoToWorldPipeline.from_pretrained(
        "nvidia/Cosmos-Predict2-2B-Video2World",
        torch_dtype=torch.bfloat16,
        token=os.environ.get("HF_TOKEN"),
    )
    pipe.to("cuda")
    print("Model loaded successfully!")
44
 
45
@app.post("/")
async def predict(request: dict):
    """Generate a short video from an input image plus a text prompt.

    Accepts either a flat JSON body or one nested under an ``"inputs"`` key
    (Hugging Face Inference Endpoints convention). The image may be an
    http(s) URL or base64-encoded bytes.

    Returns:
        dict with ``"video"`` (base64-encoded MP4) and ``"content_type"``.

    Raises:
        HTTPException: 400 for missing/invalid inputs, 500 on inference failure.
    """
    global pipe, export_to_video

    # Handle both direct and nested input formats.
    inputs = request.get("inputs", request)

    image_data = inputs.get("image")
    if not image_data:
        raise HTTPException(status_code=400, detail="No image provided")

    prompt = inputs.get("prompt", "")
    if not prompt:
        raise HTTPException(status_code=400, detail="No prompt provided")

    # Load the image: URLs go through the diffusers loader, anything else is
    # assumed to be base64-encoded bytes.
    try:
        if image_data.startswith("http"):
            from diffusers.utils import load_image
            image = load_image(image_data)
        else:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}") from e

    negative_prompt = inputs.get("negative_prompt", "ugly, static, blurry, low quality")
    num_frames = inputs.get("num_frames", 93)
    num_inference_steps = inputs.get("num_inference_steps", 35)
    guidance_scale = inputs.get("guidance_scale", 7.0)
    seed = inputs.get("seed")

    # Explicit `is not None` check so seed=0 is honored as a valid seed.
    generator = None
    if seed is not None:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

    try:
        output = pipe(
            image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )

        # Use a unique temp file per request: a fixed "/tmp/output.mp4"
        # would be racy/corrupting under concurrent requests, and the old
        # code never deleted it (disk leak).
        import tempfile
        fd, video_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        try:
            export_to_video(output.frames[0], video_path, fps=16)
            with open(video_path, "rb") as f:
                video_b64 = base64.b64encode(f.read()).decode("utf-8")
        finally:
            os.remove(video_path)

        return {"video": video_b64, "content_type": "video/mp4"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e
102
 
103
@app.get("/health")
async def health():
    """Liveness probe: always reports the service as healthy."""
    status_payload = {"status": "healthy"}
    return status_payload