File size: 4,414 Bytes
4233b79
 
 
 
 
 
 
 
d95102e
4233b79
 
 
 
 
 
850ecac
4233b79
 
 
 
 
 
 
 
 
 
 
 
850ecac
4233b79
 
 
 
 
 
 
 
 
 
 
 
850ecac
4233b79
850ecac
4233b79
d95102e
 
 
 
 
 
 
 
 
 
4233b79
 
 
850ecac
4233b79
 
 
 
 
 
 
 
 
 
 
 
850ecac
4233b79
 
850ecac
4233b79
 
 
 
2b2bf3c
0e4048f
0c995b9
2b2bf3c
4233b79
 
 
0e4048f
4233b79
 
 
 
 
d95102e
 
 
850ecac
 
 
 
 
 
 
 
 
4233b79
 
 
 
 
 
 
d95102e
 
 
 
4233b79
 
 
2b2bf3c
 
4233b79
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import base64
import io
import os
from typing import Optional
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import gc

# ASGI application; the startup hook and the route handlers below register on it.
app = FastAPI()

# Global pipeline
# `pipe` and `export_to_video` start as None and are assigned by the
# startup hook (`load_model`), which imports diffusers lazily.
pipe = None
export_to_video = None
# Device string handed to `pipe.to(...)` whenever the pipeline is moved.
DEVICE = "cuda"

class InferenceRequest(BaseModel):
    """Schema of the fields an inference request may carry.

    NOTE(review): the ``predict`` handler in this file accepts a raw ``dict``
    rather than this model, and uses its own (longer) default negative
    prompt, so the defaults declared here are not what that handler
    applies — confirm whether this schema is still meant to be wired in.
    """
    # Input image, either base64-encoded bytes or a URL.
    image: str  # base64 or URL
    # Text prompt; required (no default).
    prompt: str
    # Text the generation should steer away from.
    negative_prompt: str = "ugly, static, blurry, low quality"
    # Number of video frames to generate.
    num_frames: int = 93
    # Number of denoising steps — presumably forwarded to the pipeline; TODO confirm.
    num_inference_steps: int = 35
    # Guidance strength — presumably classifier-free guidance; TODO confirm.
    guidance_scale: float = 7.0
    # Optional integer seed; None when the caller does not supply one.
    seed: Optional[int] = None

@app.on_event("startup")
async def load_model():
    """Startup hook: load the Cosmos Video2World pipeline into module globals.

    diffusers is imported lazily here. The ``export_to_video`` helper and
    the loaded pipeline are published through the module-level
    ``export_to_video`` and ``pipe`` names, and the pipeline is moved to
    ``DEVICE``.
    """
    global pipe, export_to_video
    from diffusers import Cosmos2VideoToWorldPipeline as pipeline_cls
    from diffusers.utils import export_to_video as video_exporter

    export_to_video = video_exporter
    checkpoint = "nvidia/Cosmos-Predict2-2B-Video2World"

    print("Loading model...")
    load_options = {
        "torch_dtype": torch.bfloat16,
        # None when HF_TOKEN is absent from the environment.
        "token": os.environ.get("HF_TOKEN"),
    }
    pipe = pipeline_cls.from_pretrained(checkpoint, **load_options)
    pipe = pipe.to(DEVICE)
    print("Model loaded successfully!")
    print(f"Pipeline device: {pipe.device}")

def ensure_on_device():
    """Move the global pipeline and its text encoder onto ``DEVICE``.

    Intended to run before each inference. After moving the components it
    runs a garbage-collection pass and then empties the CUDA caching
    allocator.

    Raises:
        RuntimeError: if the pipeline has not been loaded yet
            (``pipe`` is still ``None``).
    """
    global pipe
    if pipe is None:
        # Fail with a clear message instead of "'NoneType' has no attribute 'to'".
        raise RuntimeError("Model pipeline is not loaded yet")

    pipe = pipe.to(DEVICE)
    # Force text_encoder onto the device explicitly (this is the problematic component)
    text_encoder = getattr(pipe, "text_encoder", None)
    if text_encoder is not None:
        pipe.text_encoder = text_encoder.to(DEVICE)

    # Collect first: memory released by the garbage collector goes back to
    # the caching allocator, so empty_cache() must run afterwards to
    # actually hand it back to the driver.
    gc.collect()
    torch.cuda.empty_cache()

@app.post("/predict")
@app.post("/")
async def predict(request: dict):
    """Generate a video from an input image and a text prompt.

    The payload may be flat or nested under an ``"inputs"`` key.

    Recognised fields:
        image: base64-encoded image bytes, or a string starting with
            ``"http"`` which is fetched via ``diffusers.utils.load_image``.
        prompt: non-empty text prompt.
        negative_prompt, num_frames, num_inference_steps, guidance_scale:
            optional overrides forwarded to the pipeline.
        seed: optional integer; when present a seeded ``torch.Generator``
            is passed to the pipeline so the run is repeatable.

    Returns:
        ``{"video": <base64-encoded MP4>, "content_type": "video/mp4"}``.

    Raises:
        HTTPException: 400 for a malformed payload, a missing image or
            prompt, an unreadable image, or an invalid seed; 500 when
            inference or video encoding fails.

    NOTE: the pipeline call below is synchronous, so this coroutine occupies
    the event loop for the whole inference.
    """
    import tempfile
    import traceback

    # Handle both direct and nested input formats
    inputs = request.get("inputs", request)
    if not isinstance(inputs, dict):
        raise HTTPException(status_code=400, detail="'inputs' must be a JSON object")

    image_data = inputs.get("image")
    if not image_data:
        raise HTTPException(status_code=400, detail="No image provided")

    prompt = inputs.get("prompt", "")
    if not prompt:
        raise HTTPException(status_code=400, detail="No prompt provided")

    # Load image
    try:
        if image_data.startswith("http"):
            from diffusers.utils import load_image
            image = load_image(image_data)
        else:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Resize to expected dimensions for Cosmos Video2World (720P model)
        image = image.resize((1280, 704), Image.Resampling.LANCZOS)

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}") from e

    negative_prompt = inputs.get("negative_prompt", "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering.")

    pipeline_kwargs = {
        "image": image,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_frames": inputs.get("num_frames", 93),
        "num_inference_steps": inputs.get("num_inference_steps", 35),
        "guidance_scale": inputs.get("guidance_scale", 7.0),
    }

    # Only pass a generator when the caller asked for a specific seed, so
    # unseeded requests keep the pipeline's default sampling behaviour.
    seed = inputs.get("seed")
    if seed is not None:
        try:
            pipeline_kwargs["generator"] = torch.Generator().manual_seed(int(seed))
        except (TypeError, ValueError, OverflowError, RuntimeError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid seed: {str(e)}") from e

    try:
        # Ensure all components are on the target device before each inference
        ensure_on_device()

        with torch.inference_mode():
            output = pipe(**pipeline_kwargs)

        # Use a per-request temporary directory: a fixed /tmp/output.mp4 is
        # shared by every request and worker process and is never removed.
        with tempfile.TemporaryDirectory() as tmp_dir:
            video_path = os.path.join(tmp_dir, "output.mp4")
            export_to_video(output.frames[0], video_path, fps=16)

            with open(video_path, "rb") as f:
                video_b64 = base64.b64encode(f.read()).decode("utf-8")

        return {"video": video_b64, "content_type": "video/mp4"}

    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    finally:
        # Clean up after inference — also when it failed (e.g. out of
        # memory), so cached GPU memory is not held until the next request.
        gc.collect()
        torch.cuda.empty_cache()

@app.get("/health")
@app.get("/")
async def health():
    """Health probe: return a fixed status payload naming this API."""
    payload = dict(
        status="healthy",
        message="Cosmos-Predict2 Video2World API",
    )
    return payload