samwell commited on
Commit
5cf4cc1
·
verified ·
1 Parent(s): 4263d53

Fix CUDA device mismatch - resize image and add autocast

Browse files
Files changed (1) hide show
  1. handler.py +113 -0
handler.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import base64
import io
import os
import tempfile
from typing import Optional

from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
import torch
9
+
10
app = FastAPI()

# Global pipeline
# pipe holds the loaded diffusers pipeline; export_to_video is bound to
# diffusers.utils.export_to_video. Both are None/populated by the startup
# hook (load_model) so that import of this module stays cheap.
pipe = None
export_to_video = None
15
+
16
class InferenceRequest(BaseModel):
    """Request schema for video generation.

    NOTE(review): declared but not used by predict(), which accepts a raw
    dict to tolerate both flat and {"inputs": {...}} payloads — confirm
    whether this model is kept intentionally for documentation.
    """
    image: str  # base64 or URL
    prompt: str
    negative_prompt: str = "ugly, static, blurry, low quality"
    num_frames: int = 93
    num_inference_steps: int = 35
    guidance_scale: float = 7.0
    seed: Optional[int] = None
24
+
25
@app.on_event("startup")
async def load_model():
    """Load the Cosmos-Predict2 Video2World pipeline onto the GPU once at startup.

    Populates the module globals `pipe` and `export_to_video`. diffusers is
    imported lazily here so the module can be imported without it installed.
    """
    global pipe, export_to_video

    # Deferred imports: keep module import light and fail at startup instead.
    from diffusers import Cosmos2VideoToWorldPipeline
    from diffusers.utils import export_to_video as etv

    export_to_video = etv

    print("Loading model...")
    loaded = Cosmos2VideoToWorldPipeline.from_pretrained(
        "nvidia/Cosmos-Predict2-2B-Video2World",
        torch_dtype=torch.bfloat16,
        token=os.environ.get("HF_TOKEN"),  # gated model: needs an HF token
    )
    loaded.to("cuda")
    pipe = loaded
    print("Model loaded successfully!")
42
+
43
@app.post("/predict")
@app.post("/")
async def predict(request: dict):
    """Generate a video from an input image + prompt; return it as base64 MP4.

    Accepts either a flat JSON body or one nested under "inputs".
    Raises:
        HTTPException 400 — missing/invalid image or prompt.
        HTTPException 503 — model not loaded yet (request raced startup).
        HTTPException 500 — inference or encoding failure.
    """
    global pipe, export_to_video

    # Guard: a request arriving before the startup hook finished would
    # otherwise crash with an opaque "NoneType is not callable" 500.
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Handle both direct and nested input formats
    inputs = request.get("inputs", request)

    image_data = inputs.get("image")
    if not image_data:
        raise HTTPException(status_code=400, detail="No image provided")

    prompt = inputs.get("prompt", "")
    if not prompt:
        raise HTTPException(status_code=400, detail="No prompt provided")

    # Load image (URL or base64)
    try:
        if image_data.startswith("http"):
            from diffusers.utils import load_image
            image = load_image(image_data)
        else:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Resize to expected dimensions for Cosmos Video2World
        image = image.resize((1280, 704))

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}")

    negative_prompt = inputs.get("negative_prompt", "ugly, static, blurry, low quality")
    # Coerce numeric knobs so JSON string values ("35") never reach the pipeline.
    num_frames = int(inputs.get("num_frames", 93))
    num_inference_steps = int(inputs.get("num_inference_steps", 35))
    guidance_scale = float(inputs.get("guidance_scale", 7.0))
    seed = inputs.get("seed")

    # Create generator on correct device (CPU generators + CUDA model mismatch)
    generator = None
    if seed is not None:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

    try:
        # torch.cuda.amp.autocast is deprecated; torch.autocast("cuda", ...)
        # is the supported spelling with identical behavior.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            output = pipe(
                image=image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )

        # Per-request temp file: the previous fixed /tmp/output.mp4 let
        # concurrent requests overwrite each other's video and leaked to disk.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            video_path = tmp.name
        try:
            export_to_video(output.frames[0], video_path, fps=16)
            with open(video_path, "rb") as f:
                video_b64 = base64.b64encode(f.read()).decode("utf-8")
        finally:
            # Best-effort cleanup; never fail the request on removal errors.
            try:
                os.remove(video_path)
            except OSError:
                pass

        return {"video": video_b64, "content_type": "video/mp4"}

    except HTTPException:
        # Don't wrap our own 4xx/5xx responses in a generic 500.
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
109
+
110
@app.get("/health")
@app.get("/")
async def health():
    """Liveness probe: report that the API process is up."""
    payload = {
        "status": "healthy",
        "message": "Cosmos-Predict2 Video2World API",
    }
    return payload