import base64
import io
import os
import tempfile
from typing import Any, Dict

import torch
from PIL import Image
class EndpointHandler:
    """Inference-endpoint handler for NVIDIA Cosmos-Predict2 video-to-world.

    Takes a single conditioning image plus a text prompt and returns a short
    generated video clip as base64-encoded MP4.
    """

    def __init__(self, path: str = ""):
        # Lazy import: keeps this module importable on machines without
        # diffusers (the serving framework imports the handler before use).
        from diffusers import Cosmos2VideoToWorldPipeline
        from diffusers.utils import export_to_video

        self.export_to_video = export_to_video
        model_id = "nvidia/Cosmos-Predict2-2B-Video2World"
        self.pipe = Cosmos2VideoToWorldPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        )
        self.pipe.to("cuda")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one generation request.

        Payload fields (top-level or nested under ``"inputs"``):
            image: URL, raw base64, or ``data:`` URI of the conditioning frame.
            prompt: required text prompt.
            negative_prompt / num_frames / num_inference_steps /
            guidance_scale / seed: optional generation controls.

        Returns ``{"video": <base64 mp4>, "content_type": "video/mp4"}`` on
        success, or ``{"error": <message>}`` on failure (never raises).
        """
        inputs = data.get("inputs", data)

        # Validate cheap required fields BEFORE any image decode/download:
        # the original fetched the image first and only then discovered a
        # missing prompt, wasting the network round-trip.
        image_data = inputs.get("image")
        if not image_data:
            return {"error": "No image provided"}
        prompt = inputs.get("prompt", "")
        if not prompt:
            return {"error": "No prompt provided"}

        try:
            if image_data.startswith("http"):
                from diffusers.utils import load_image
                image = load_image(image_data)
            else:
                # Accept both raw base64 and data URIs
                # ("data:image/png;base64,...."): strip the URI header.
                if image_data.startswith("data:"):
                    image_data = image_data.split(",", 1)[-1]
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to load image: {str(e)}"}

        negative_prompt = inputs.get("negative_prompt", "ugly, static, blurry, low quality")
        # Coerce numeric fields: JSON clients frequently send them as strings.
        num_frames = int(inputs.get("num_frames", 93))
        num_inference_steps = int(inputs.get("num_inference_steps", 35))
        guidance_scale = float(inputs.get("guidance_scale", 7.0))
        seed = inputs.get("seed")
        # BUG FIX: the original used `if seed`, which silently discarded a
        # legitimate seed of 0; test explicitly against None.
        generator = (
            torch.Generator(device="cuda").manual_seed(int(seed))
            if seed is not None
            else None
        )

        try:
            output = self.pipe(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )
            # Unique temp file per request: the original's fixed
            # /tmp/output.mp4 raced under concurrent invocations and was
            # never deleted.
            fd, video_path = tempfile.mkstemp(suffix=".mp4")
            os.close(fd)
            try:
                self.export_to_video(output.frames[0], video_path, fps=16)
                with open(video_path, "rb") as f:
                    video_b64 = base64.b64encode(f.read()).decode("utf-8")
            finally:
                os.remove(video_path)
            return {"video": video_b64, "content_type": "video/mp4"}
        except Exception as e:
            return {"error": f"Inference failed: {str(e)}"}