|
|
import torch |
|
|
import base64 |
|
|
import io |
|
|
import os |
|
|
from typing import Optional |
|
|
from PIL import Image |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel |
|
|
import gc |
|
|
|
|
|
# ASGI application object; the route and startup decorators in this file
# register their handlers on it.
app = FastAPI()

# Both are filled in by the startup hook and stay None until the model has
# been loaded.
pipe = export_to_video = None

# Torch device string the pipeline is moved to.
DEVICE = "cuda"
|
|
|
|
|
class InferenceRequest(BaseModel):
    """Schema describing the fields of a generation request.

    NOTE(review): the ``predict`` handler in this file accepts a raw ``dict``
    and reads the same keys by hand, so this model is not referenced by the
    visible code -- confirm whether other modules use it.
    """

    # Required inputs (no default).
    image: str
    prompt: str

    # Optional tuning knobs with their defaults.
    negative_prompt: str = "ugly, static, blurry, low quality"
    num_frames: int = 93
    num_inference_steps: int = 35
    guidance_scale: float = 7.0

    # Optional seed; ``None`` means no seed was supplied.
    seed: Optional[int] = None
|
|
|
|
|
@app.on_event("startup")
async def load_model():
    """Startup hook: load the Video2World pipeline and publish it globally.

    Binds the module-level ``pipe`` (moved to ``DEVICE``) and
    ``export_to_video`` names that the request handlers rely on. The Hugging
    Face token is read from the ``HF_TOKEN`` environment variable and may be
    absent.
    """
    global pipe, export_to_video

    # diffusers is imported here, inside the hook, exactly as the original did.
    from diffusers import Cosmos2VideoToWorldPipeline
    from diffusers.utils import export_to_video as _export_to_video

    export_to_video = _export_to_video

    print("Loading model...")
    hf_token = os.getenv("HF_TOKEN")
    pipe = Cosmos2VideoToWorldPipeline.from_pretrained(
        "nvidia/Cosmos-Predict2-2B-Video2World",
        torch_dtype=torch.bfloat16,
        token=hf_token,
    )
    pipe = pipe.to(DEVICE)
    print("Model loaded successfully!")
    print(f"Pipeline device: {pipe.device}")
|
|
|
|
|
def ensure_on_device():
    """Move the loaded pipeline, and its text encoder if any, onto ``DEVICE``.

    Rebinds the module-level ``pipe`` to the result of ``pipe.to(DEVICE)``,
    then moves ``pipe.text_encoder`` explicitly when the pipeline exposes one.
    Finishes by collecting garbage and releasing cached CUDA memory.

    Raises:
        RuntimeError: if the pipeline has not been loaded yet (``pipe`` is
            still ``None``).
    """
    global pipe

    # Fail with a clear message instead of "'NoneType' has no attribute 'to'".
    if pipe is None:
        raise RuntimeError("Pipeline is not loaded; cannot move it to the device")

    pipe = pipe.to(DEVICE)

    text_encoder = getattr(pipe, "text_encoder", None)
    if text_encoder is not None:
        pipe.text_encoder = text_encoder.to(DEVICE)

    # Collect first so that memory held by unreachable objects is freed before
    # the CUDA caching allocator is asked to release its unused blocks.
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
|
# Negative prompt used when the request does not supply one.
_DEFAULT_NEGATIVE_PROMPT = (
    "The video captures a series of frames showing ugly scenes, static with no "
    "motion, motion blur, over-saturation, shaky footage, low resolution, grainy "
    "texture, pixelated images, poorly lit areas, underexposed and overexposed "
    "scenes, poor color balance, washed out colors, choppy sequences, jerky "
    "movements, low frame rate, artifacting, color banding, unnatural "
    "transitions, outdated special effects, fake elements, unconvincing visuals, "
    "poorly edited content, jump cuts, visual noise, and flickering."
)


def _decode_input_image(image_data):
    """Return the request image as a PIL image resized to 1280x704.

    ``image_data`` is either an http(s) URL or a base64-encoded image.
    """
    if not isinstance(image_data, str):
        raise TypeError("image must be a URL or a base64-encoded string")

    # Match the full scheme: the characters "http" are also valid base64, so a
    # bare startswith("http") could misroute an encoded payload to the URL path.
    if image_data.startswith(("http://", "https://")):
        # NOTE(security): this fetches a caller-supplied URL from the server;
        # restrict allowed hosts if the endpoint is exposed to untrusted clients.
        from diffusers.utils import load_image

        image = load_image(image_data)
    else:
        image_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    return image.resize((1280, 704), Image.Resampling.LANCZOS)


def _render_video_base64(frames, fps=16):
    """Encode ``frames`` as an MP4 and return the file content as base64 text."""
    import tempfile

    # A per-request directory avoids collisions on a shared fixed path and is
    # removed (with the video inside it) as soon as the block exits.
    with tempfile.TemporaryDirectory() as workdir:
        video_path = os.path.join(workdir, "output.mp4")
        export_to_video(frames, video_path, fps=fps)
        with open(video_path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")


@app.post("/predict")
@app.post("/")
async def predict(request: dict):
    """Generate a video from an input image and a text prompt.

    The JSON body may carry the parameters at the top level or nested under an
    ``"inputs"`` key. Recognised keys:

    * ``image`` (required): http(s) URL or base64-encoded image.
    * ``prompt`` (required): text prompt.
    * ``negative_prompt``, ``num_frames`` (default 93),
      ``num_inference_steps`` (default 35), ``guidance_scale`` (default 7.0).
    * ``seed`` (optional): integer seed for the random generator.

    Returns:
        ``{"video": <base64 MP4>, "content_type": "video/mp4"}``.

    Raises:
        HTTPException: 400 for missing or invalid inputs, 503 when the model
            has not been loaded, 500 when inference or encoding fails.

    NOTE: the pipeline call is synchronous inside this async handler, so the
    handler does not yield to the event loop while a video is being generated.
    """
    inputs = request.get("inputs", request)
    if not isinstance(inputs, dict):
        raise HTTPException(status_code=400, detail="'inputs' must be a JSON object")

    image_data = inputs.get("image")
    if not image_data:
        raise HTTPException(status_code=400, detail="No image provided")

    prompt = inputs.get("prompt", "")
    if not prompt:
        raise HTTPException(status_code=400, detail="No prompt provided")

    try:
        image = _decode_input_image(image_data)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}") from e

    negative_prompt = inputs.get("negative_prompt", _DEFAULT_NEGATIVE_PROMPT)

    # Coerce numeric parameters up front so that bad values are reported as a
    # client error rather than surfacing from inside the pipeline.
    try:
        num_frames = int(inputs.get("num_frames", 93))
        num_inference_steps = int(inputs.get("num_inference_steps", 35))
        guidance_scale = float(inputs.get("guidance_scale", 7.0))
        seed = inputs.get("seed")
        if seed is not None:
            seed = int(seed)
    except (TypeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid generation parameter: {e}") from e

    if pipe is None or export_to_video is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    pipe_kwargs = {
        "image": image,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_frames": num_frames,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }

    output = None
    try:
        if seed is not None:
            # Only passed when requested, so unseeded calls behave as before.
            pipe_kwargs["generator"] = torch.Generator(device="cpu").manual_seed(seed)

        ensure_on_device()

        with torch.inference_mode():
            output = pipe(**pipe_kwargs)

        video_b64 = _render_video_base64(output.frames[0])
        return {"video": video_b64, "content_type": "video/mp4"}

    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") from e

    finally:
        # Runs on success and on failure, so memory is released after an error
        # as well. Drop the frames and collect before emptying the CUDA cache.
        del output
        gc.collect()
        torch.cuda.empty_cache()
|
|
|
|
|
@app.get("/health")
@app.get("/")
async def health():
    """Return a static payload identifying the service as healthy."""
    payload = {
        "status": "healthy",
        "message": "Cosmos-Predict2 Video2World API",
    }
    return payload
|
|
|