File size: 3,489 Bytes
5cf4cc1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import torch
import base64
import io
import os
from typing import Optional
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# ASGI application object; the route and startup decorators in this file attach to it.
app = FastAPI()
# Global pipeline
# Both names start as None and are assigned by load_model() during the
# FastAPI "startup" event; predict() reads them when serving a request.
pipe = None
export_to_video = None
class InferenceRequest(BaseModel):
    # Schema describing one generation request: a conditioning image, a text
    # prompt, and optional sampling parameters with their defaults.

    # Required inputs.
    image: str  # base64 or URL
    prompt: str

    # Optional sampling controls; the defaults mirror the fallback values
    # predict() applies when a key is missing from the request body.
    negative_prompt: str = "ugly, static, blurry, low quality"
    num_frames: int = 93
    num_inference_steps: int = 35
    guidance_scale: float = 7.0

    # Optional RNG seed; None (the default) means no seed was supplied.
    seed: Optional[int] = None
@app.on_event("startup")
async def load_model():
    """Load the Cosmos Video2World pipeline into the module-level globals.

    Registered as a FastAPI "startup" handler. On completion ``pipe`` holds
    the pipeline moved to the CUDA device and ``export_to_video`` holds the
    diffusers video-export helper.
    """
    global pipe, export_to_video

    # diffusers is imported inside the hook rather than at module top level
    # (presumably to keep importing this module cheap -- confirm).
    from diffusers import Cosmos2VideoToWorldPipeline
    from diffusers.utils import export_to_video as etv

    export_to_video = etv

    model_id = "nvidia/Cosmos-Predict2-2B-Video2World"
    hub_token = os.environ.get("HF_TOKEN")  # None when the variable is unset

    print("Loading model...")
    load_kwargs = {"torch_dtype": torch.bfloat16, "token": hub_token}
    pipe = Cosmos2VideoToWorldPipeline.from_pretrained(model_id, **load_kwargs)
    pipe.to("cuda")
    print("Model loaded successfully!")
def _decode_request_image(image_data):
    """Load the conditioning image from a URL, data URI, or bare base64 string."""
    if image_data.startswith(("http://", "https://")):
        # NOTE(security): this fetches a caller-supplied URL from the server
        # (SSRF risk). Restrict allowed hosts or disable this branch if the
        # endpoint is exposed to untrusted clients.
        from diffusers.utils import load_image

        image = load_image(image_data)
    else:
        if image_data.startswith("data:"):
            # Drop a "data:image/...;base64," prefix. b64decode() silently
            # discards non-alphabet characters, so leaving the prefix in
            # place would corrupt the decoded bytes instead of failing fast.
            image_data = image_data.partition(",")[2]
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # Resize to expected dimensions for Cosmos Video2World
    return image.resize((1280, 704))


@app.post("/predict")
@app.post("/")
async def predict(request: dict):
    """Generate a video from a conditioning image and a text prompt.

    The JSON body may be flat or nested under an ``"inputs"`` key. Required
    keys are ``image`` (http(s) URL, data URI, or base64 string) and
    ``prompt``; the optional keys and their defaults are those declared on
    ``InferenceRequest``.

    Returns:
        A dict with ``"video"`` (base64-encoded MP4 bytes) and
        ``"content_type"`` (``"video/mp4"``).

    Raises:
        HTTPException: 400 for a malformed body, a missing image/prompt, or an
            image that cannot be loaded; 422 when a parameter has the wrong
            type; 503 when the model has not been loaded yet; 500 when
            inference or video export fails.
    """
    import tempfile
    import traceback

    from pydantic import ValidationError

    # Handle both direct and nested input formats
    inputs = request.get("inputs", request)
    if not isinstance(inputs, dict):
        raise HTTPException(status_code=400, detail="'inputs' must be a JSON object")

    image_data = inputs.get("image")
    if not image_data:
        raise HTTPException(status_code=400, detail="No image provided")
    prompt = inputs.get("prompt", "")
    if not prompt:
        raise HTTPException(status_code=400, detail="No prompt provided")

    # Fail with a clear status if a request arrives before startup finished,
    # instead of an opaque "'NoneType' object is not callable" 500.
    if pipe is None or export_to_video is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    # Load image
    try:
        image = _decode_request_image(image_data)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}") from e

    # Validate and coerce the remaining parameters through the request schema
    # so bad client input is reported as a client error, not a 500 raised
    # from deep inside the pipeline.
    try:
        params = InferenceRequest(**inputs)
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=f"Invalid parameters: {str(e)}") from e

    # NOTE(review): the pipeline call below is synchronous inside this
    # coroutine, so the event loop (including /health) is blocked while a
    # video is being generated.
    try:
        # Create generator on correct device
        generator = None
        if params.seed is not None:
            generator = torch.Generator(device="cuda").manual_seed(params.seed)

        # torch.autocast(...) replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            output = pipe(
                image=image,
                prompt=params.prompt,
                negative_prompt=params.negative_prompt,
                num_frames=params.num_frames,
                num_inference_steps=params.num_inference_steps,
                guidance_scale=params.guidance_scale,
                generator=generator,
            )

        # A per-request temporary directory avoids clobbering a shared
        # /tmp/output.mp4 across workers and removes the file afterwards.
        with tempfile.TemporaryDirectory() as tmp_dir:
            video_path = os.path.join(tmp_dir, "output.mp4")
            export_to_video(output.frames[0], video_path, fps=16)
            with open(video_path, "rb") as f:
                video_b64 = base64.b64encode(f.read()).decode("utf-8")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    return {"video": video_b64, "content_type": "video/mp4"}
@app.get("/health")
@app.get("/")
async def health():
    """Return a static payload indicating the service is up."""
    status_payload = dict(
        status="healthy",
        message="Cosmos-Predict2 Video2World API",
    )
    return status_payload
|