samwell committed on
Commit
850ecac
·
verified ·
1 Parent(s): 0c995b9

Explicitly move all pipeline components to CUDA, use inference_mode

Browse files
Files changed (1) hide show
  1. handler.py +26 -16
handler.py CHANGED
@@ -12,6 +12,7 @@ app = FastAPI()
12
  # Global pipeline
13
  pipe = None
14
  export_to_video = None
 
15
 
16
  class InferenceRequest(BaseModel):
17
  image: str # base64 or URL
@@ -24,7 +25,7 @@ class InferenceRequest(BaseModel):
24
 
25
  @app.on_event("startup")
26
  async def load_model():
27
- global pipe, export_to_video
28
  from diffusers import Cosmos2VideoToWorldPipeline
29
  from diffusers.utils import export_to_video as etv
30
 
@@ -37,13 +38,23 @@ async def load_model():
37
  torch_dtype=torch.bfloat16,
38
  token=os.environ.get("HF_TOKEN"),
39
  )
40
- pipe.to("cuda")
 
 
 
 
 
 
 
 
 
41
  print("Model loaded successfully!")
 
42
 
43
  @app.post("/predict")
44
  @app.post("/")
45
  async def predict(request: dict):
46
- global pipe, export_to_video
47
 
48
  # Handle both direct and nested input formats
49
  inputs = request.get("inputs", request)
@@ -56,14 +67,12 @@ async def predict(request: dict):
56
  if not prompt:
57
  raise HTTPException(status_code=400, detail="No prompt provided")
58
 
59
- # Load image using diffusers' load_image for consistent preprocessing
60
  try:
61
- from diffusers.utils import load_image
62
-
63
  if image_data.startswith("http"):
 
64
  image = load_image(image_data)
65
  else:
66
- # Save base64 to temp file and load with load_image for consistent handling
67
  image_bytes = base64.b64decode(image_data)
68
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
69
 
@@ -79,15 +88,16 @@ async def predict(request: dict):
79
  guidance_scale = inputs.get("guidance_scale", 7.0)
80
 
81
  try:
82
- # Run inference WITHOUT generator to avoid device mismatch
83
- output = pipe(
84
- image=image,
85
- prompt=prompt,
86
- negative_prompt=negative_prompt,
87
- num_frames=num_frames,
88
- num_inference_steps=num_inference_steps,
89
- guidance_scale=guidance_scale,
90
- )
 
91
 
92
  video_path = "/tmp/output.mp4"
93
  export_to_video(output.frames[0], video_path, fps=16)
 
12
  # Global pipeline
13
  pipe = None
14
  export_to_video = None
15
+ DEVICE = "cuda"
16
 
17
  class InferenceRequest(BaseModel):
18
  image: str # base64 or URL
 
25
 
26
  @app.on_event("startup")
27
  async def load_model():
28
+ global pipe, export_to_video, DEVICE
29
  from diffusers import Cosmos2VideoToWorldPipeline
30
  from diffusers.utils import export_to_video as etv
31
 
 
38
  torch_dtype=torch.bfloat16,
39
  token=os.environ.get("HF_TOKEN"),
40
  )
41
+ pipe = pipe.to(DEVICE)
42
+
43
+ # Ensure all components are on the same device
44
+ if hasattr(pipe, 'text_encoder') and pipe.text_encoder is not None:
45
+ pipe.text_encoder = pipe.text_encoder.to(DEVICE)
46
+ if hasattr(pipe, 'vae') and pipe.vae is not None:
47
+ pipe.vae = pipe.vae.to(DEVICE)
48
+ if hasattr(pipe, 'transformer') and pipe.transformer is not None:
49
+ pipe.transformer = pipe.transformer.to(DEVICE)
50
+
51
  print("Model loaded successfully!")
52
+ print(f"Pipeline device: {pipe.device}")
53
 
54
  @app.post("/predict")
55
  @app.post("/")
56
  async def predict(request: dict):
57
+ global pipe, export_to_video, DEVICE
58
 
59
  # Handle both direct and nested input formats
60
  inputs = request.get("inputs", request)
 
67
  if not prompt:
68
  raise HTTPException(status_code=400, detail="No prompt provided")
69
 
70
+ # Load image
71
  try:
 
 
72
  if image_data.startswith("http"):
73
+ from diffusers.utils import load_image
74
  image = load_image(image_data)
75
  else:
 
76
  image_bytes = base64.b64decode(image_data)
77
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
78
 
 
88
  guidance_scale = inputs.get("guidance_scale", 7.0)
89
 
90
  try:
91
+ # Run without generator - let pipeline handle device placement
92
+ with torch.inference_mode():
93
+ output = pipe(
94
+ image=image,
95
+ prompt=prompt,
96
+ negative_prompt=negative_prompt,
97
+ num_frames=num_frames,
98
+ num_inference_steps=num_inference_steps,
99
+ guidance_scale=guidance_scale,
100
+ )
101
 
102
  video_path = "/tmp/output.mp4"
103
  export_to_video(output.frames[0], video_path, fps=16)