File size: 2,431 Bytes
a44999a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import base64
import io
import os
import tempfile
from typing import Any, Dict

import torch
from PIL import Image
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for image-to-video generation.

    Wraps NVIDIA's Cosmos-Predict2-2B-Video2World pipeline: takes an input
    image (URL or base64 string) plus a text prompt, and returns a
    base64-encoded MP4 clip.
    """

    def __init__(self, path: str = ""):
        """Load the diffusion pipeline onto the GPU.

        Args:
            path: Model directory supplied by the Endpoints runtime; unused
                here because the pipeline is pulled from the Hub by id.
        """
        # Imported lazily so the module can be imported (e.g. for inspection)
        # without diffusers installed.
        from diffusers import Cosmos2VideoToWorldPipeline
        from diffusers.utils import export_to_video

        self.export_to_video = export_to_video
        model_id = "nvidia/Cosmos-Predict2-2B-Video2World"
        # bfloat16 halves memory versus fp32 while remaining numerically
        # stable for this model family.
        self.pipe = Cosmos2VideoToWorldPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        )
        self.pipe.to("cuda")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one video-generation request.

        Args:
            data: Request payload. Parameters are read from ``data["inputs"]``
                if present, else from ``data`` itself:
                ``image`` (required, URL or base64), ``prompt`` (required),
                ``negative_prompt``, ``num_frames``, ``num_inference_steps``,
                ``guidance_scale``, ``seed``.

        Returns:
            ``{"video": <base64 mp4>, "content_type": "video/mp4"}`` on
            success, or ``{"error": <message>}`` on failure (errors are
            returned, never raised, per the Endpoints handler contract).
        """
        inputs = data.get("inputs", data)

        image_data = inputs.get("image")
        if not image_data:
            return {"error": "No image provided"}
        try:
            if image_data.startswith("http"):
                from diffusers.utils import load_image

                image = load_image(image_data)
            else:
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to load image: {str(e)}"}

        prompt = inputs.get("prompt", "")
        if not prompt:
            return {"error": "No prompt provided"}

        negative_prompt = inputs.get("negative_prompt", "ugly, static, blurry, low quality")
        # JSON payloads may carry numbers as strings; coerce explicitly so the
        # pipeline always receives proper numeric types.
        num_frames = int(inputs.get("num_frames", 93))
        num_inference_steps = int(inputs.get("num_inference_steps", 35))
        guidance_scale = float(inputs.get("guidance_scale", 7.0))

        seed = inputs.get("seed")
        # `is not None`, not truthiness: seed=0 is a valid, reproducible seed
        # (the original `if seed` silently ignored it).
        generator = (
            torch.Generator(device="cuda").manual_seed(int(seed))
            if seed is not None
            else None
        )

        try:
            output = self.pipe(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )
            # Unique temp path per request: the previous hard-coded
            # /tmp/output.mp4 let concurrent requests clobber each other and
            # leaked the file between calls.
            fd, video_path = tempfile.mkstemp(suffix=".mp4")
            os.close(fd)
            try:
                self.export_to_video(output.frames[0], video_path, fps=16)
                with open(video_path, "rb") as f:
                    video_b64 = base64.b64encode(f.read()).decode("utf-8")
            finally:
                os.remove(video_path)
            return {"video": video_b64, "content_type": "video/mp4"}
        except Exception as e:
            return {"error": f"Inference failed: {str(e)}"}
|