samwell committed on
Commit
a44999a
·
verified ·
1 Parent(s): fda3463

Add handler.py

Browse files
Files changed (1) hide show
  1. handler.py +71 -0
handler.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import base64
import io
import os
import tempfile
from typing import Dict, Any

import torch
from PIL import Image
class EndpointHandler:
    """Inference Endpoints handler for NVIDIA Cosmos-Predict2 video-to-world.

    Takes an input image (URL or base64-encoded bytes) plus a text prompt and
    returns a base64-encoded MP4 clip produced by the
    ``Cosmos2VideoToWorldPipeline``. All failures are reported as
    ``{"error": <message>}`` dictionaries rather than raised exceptions, per
    the hosting runtime's handler contract.
    """

    def __init__(self, path: str = ""):
        """Load the pipeline weights and move them to the GPU.

        Args:
            path: Model directory supplied by the hosting runtime. Currently
                unused — the pipeline is always loaded from the hub id below.
        """
        # Imported lazily so this module can be imported/inspected in
        # environments where diffusers is not installed.
        from diffusers import Cosmos2VideoToWorldPipeline
        from diffusers.utils import export_to_video

        self.export_to_video = export_to_video
        model_id = "nvidia/Cosmos-Predict2-2B-Video2World"

        self.pipe = Cosmos2VideoToWorldPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        )
        self.pipe.to("cuda")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one video-generation request.

        Expected payload fields (either top-level or nested under "inputs"):
            image: http(s) URL or base64-encoded image bytes (required).
            prompt: text prompt (required).
            negative_prompt, num_frames, num_inference_steps,
            guidance_scale, seed: optional overrides.

        Returns:
            ``{"video": <base64 mp4>, "content_type": "video/mp4"}`` on
            success, or ``{"error": <message>}`` on any failure.
        """
        inputs = data.get("inputs", data)

        image_data = inputs.get("image")
        if not image_data:
            return {"error": "No image provided"}

        try:
            if image_data.startswith("http"):
                from diffusers.utils import load_image
                image = load_image(image_data)
            else:
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to load image: {str(e)}"}

        prompt = inputs.get("prompt", "")
        if not prompt:
            return {"error": "No prompt provided"}

        negative_prompt = inputs.get("negative_prompt", "ugly, static, blurry, low quality")
        # Coerce numeric options explicitly: JSON/HTTP clients routinely send
        # them as strings, which would otherwise crash inside the pipeline.
        num_frames = int(inputs.get("num_frames", 93))
        num_inference_steps = int(inputs.get("num_inference_steps", 35))
        guidance_scale = float(inputs.get("guidance_scale", 7.0))
        seed = inputs.get("seed")

        # BUG FIX: the previous `if seed` truthiness test silently discarded a
        # legitimate seed of 0; only a missing seed should mean "no generator".
        generator = (
            torch.Generator(device="cuda").manual_seed(int(seed))
            if seed is not None
            else None
        )

        video_path = None
        try:
            output = self.pipe(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )

            # Use a unique per-request temp file: the previous fixed
            # "/tmp/output.mp4" path raced under concurrent invocations and
            # was never cleaned up.
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                video_path = tmp.name
            self.export_to_video(output.frames[0], video_path, fps=16)

            with open(video_path, "rb") as f:
                video_b64 = base64.b64encode(f.read()).decode("utf-8")

            return {"video": video_b64, "content_type": "video/mp4"}

        except Exception as e:
            return {"error": f"Inference failed: {str(e)}"}
        finally:
            # Best-effort cleanup of the temp file on both success and failure.
            if video_path and os.path.exists(video_path):
                os.remove(video_path)