File size: 3,489 Bytes
5cf4cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import base64
import io
import os
import tempfile
import threading
import traceback
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel, ValidationError

app = FastAPI()

# Shared process state. Both names start out as None and are assigned by the
# startup hook load_model(): the diffusers pipeline and diffusers'
# export_to_video helper.
pipe = export_to_video = None

class InferenceRequest(BaseModel):
    """Parameters for one image-to-video generation request.

    ``image`` and ``prompt`` are required; every other field has a default.
    """

    # Conditioning image: base64-encoded bytes or a URL.
    image: str
    # Text prompt describing the video to generate.
    prompt: str
    # Text the generation should steer away from.
    negative_prompt: str = "ugly, static, blurry, low quality"
    # Number of frames to generate.
    num_frames: int = 93
    # Number of pipeline inference steps.
    num_inference_steps: int = 35
    # Guidance scale forwarded to the pipeline.
    guidance_scale: float = 7.0
    # Optional RNG seed; None leaves generation unseeded.
    seed: Optional[int] = None

@app.on_event("startup")
async def load_model():
    """Load the Cosmos-Predict2 Video2World pipeline at application startup.

    Assigns the module globals ``pipe`` (the bfloat16 pipeline, moved to CUDA)
    and ``export_to_video`` (diffusers' helper). diffusers is imported lazily
    inside this hook. ``HF_TOKEN`` from the environment, if set, is passed as
    the Hugging Face access token.
    """
    global pipe, export_to_video
    from diffusers import Cosmos2VideoToWorldPipeline
    from diffusers.utils import export_to_video as _export_helper

    export_to_video = _export_helper

    checkpoint = "nvidia/Cosmos-Predict2-2B-Video2World"
    hf_token = os.environ.get("HF_TOKEN")

    print("Loading model...")
    pipe = Cosmos2VideoToWorldPipeline.from_pretrained(
        checkpoint, torch_dtype=torch.bfloat16, token=hf_token
    )
    pipe.to("cuda")
    print("Model loaded successfully!")

# Serializes pipeline calls. Inference runs in a worker thread so the event loop
# can keep serving other endpoints (e.g. /health); without this lock, concurrent
# requests would run the pipeline on the same GPU simultaneously.
_INFERENCE_LOCK = threading.Lock()

# Frame rate of the exported MP4.
_OUTPUT_FPS = 16

# (width, height) the conditioning image is resized to for Cosmos Video2World.
_INPUT_SIZE = (1280, 704)


def _decode_image(image_data: str) -> "Image.Image":
    """Load the conditioning image from an http(s) URL, a data URI or raw base64.

    Base64 input is decoded with PIL and converted to RGB; URLs are fetched with
    diffusers' ``load_image``. The image is then resized to ``_INPUT_SIZE``.
    Any failure propagates to the caller. Blocking: call from a worker thread.
    """
    # Require an explicit scheme: diffusers' load_image treats non-URL strings
    # as local file paths, which a client must not be able to reach.
    if image_data.startswith(("http://", "https://")):
        from diffusers.utils import load_image

        image = load_image(image_data)
    else:
        # Accept "data:image/...;base64,<payload>" as well as bare base64.
        # b64decode's default non-strict mode would otherwise fold the header
        # letters into the payload and corrupt it.
        if image_data.startswith("data:"):
            image_data = image_data.partition(",")[2]
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    return image.resize(_INPUT_SIZE)


def _generate_video_b64(image: "Image.Image", params: InferenceRequest) -> str:
    """Run the pipeline on *image* and return the exported MP4 as base64 text.

    Blocking: call from a worker thread. GPU work is serialized by
    ``_INFERENCE_LOCK``.
    """
    generator = None
    if params.seed is not None:
        # Create the generator on the same device the pipeline runs on.
        generator = torch.Generator(device="cuda").manual_seed(int(params.seed))

    # Autocast state is thread-local, so it is entered in this worker thread.
    with _INFERENCE_LOCK, torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = pipe(
            image=image,
            prompt=params.prompt,
            negative_prompt=params.negative_prompt,
            num_frames=params.num_frames,
            num_inference_steps=params.num_inference_steps,
            guidance_scale=params.guidance_scale,
            generator=generator,
        )

    # Use a per-request temporary directory. This avoids clobbering between
    # requests or worker processes, and the file is removed afterwards instead
    # of accumulating in /tmp.
    with tempfile.TemporaryDirectory() as tmp_dir:
        video_path = os.path.join(tmp_dir, "output.mp4")
        export_to_video(output.frames[0], video_path, fps=_OUTPUT_FPS)
        with open(video_path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")


@app.post("/predict")
@app.post("/")
async def predict(request: dict):
    """Generate a video from a conditioning image and a text prompt.

    The JSON body is either the parameters themselves or ``{"inputs": {...}}``.
    ``image`` (http(s) URL, data URI or base64) and ``prompt`` are required.
    The remaining fields and their defaults come from ``InferenceRequest``.

    Returns:
        ``{"video": <base64-encoded MP4>, "content_type": "video/mp4"}``.

    Raises:
        HTTPException: 400 for a malformed body, a missing image or prompt,
            invalid parameters, or an unloadable image; 503 if the model is not
            loaded; 500 if the pipeline or the video export fails.
    """
    # Handle both direct and nested input formats
    inputs = request.get("inputs", request)
    if not isinstance(inputs, dict):
        raise HTTPException(status_code=400, detail="'inputs' must be a JSON object")

    if not inputs.get("image"):
        raise HTTPException(status_code=400, detail="No image provided")

    if not inputs.get("prompt", ""):
        raise HTTPException(status_code=400, detail="No prompt provided")

    # Validate and coerce the remaining fields against the request model. This
    # keeps a single source of defaults and turns bad types (e.g. num_frames="abc")
    # into a 400 rather than a pipeline crash.
    try:
        params = InferenceRequest(**inputs)
    except ValidationError as e:
        raise HTTPException(status_code=400, detail=f"Invalid parameters: {e}") from e

    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Decoding may download a URL or parse a large image; keep it off the event loop.
    try:
        image = await run_in_threadpool(_decode_image, params.image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load image: {e}") from e

    try:
        video_b64 = await run_in_threadpool(_generate_video_b64, image, params)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}") from e

    return {"video": video_b64, "content_type": "video/mp4"}

@app.get("/health")
@app.get("/")
async def health():
    """Return a static "healthy" status payload.

    The response is constant; it does not check whether the model is loaded.
    """
    payload = {"status": "healthy", "message": "Cosmos-Predict2 Video2World API"}
    return payload