samwell committed on
Commit
2b2bf3c
·
verified ·
1 Parent(s): aed4085

Fix CUDA device mismatch - resize image and add autocast

Browse files
Files changed (1) hide show
  1. handler.py +17 -9
handler.py CHANGED
@@ -64,6 +64,10 @@ async def predict(request: dict):
64
  else:
65
  image_bytes = base64.b64decode(image_data)
66
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
 
 
 
67
  except Exception as e:
68
  raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}")
69
 
@@ -73,20 +77,22 @@ async def predict(request: dict):
73
  guidance_scale = inputs.get("guidance_scale", 7.0)
74
  seed = inputs.get("seed")
75
 
 
76
  generator = None
77
  if seed is not None:
78
  generator = torch.Generator(device="cuda").manual_seed(int(seed))
79
 
80
  try:
81
- output = pipe(
82
- image=image,
83
- prompt=prompt,
84
- negative_prompt=negative_prompt,
85
- num_frames=num_frames,
86
- num_inference_steps=num_inference_steps,
87
- guidance_scale=guidance_scale,
88
- generator=generator,
89
- )
 
90
 
91
  video_path = "/tmp/output.mp4"
92
  export_to_video(output.frames[0], video_path, fps=16)
@@ -97,6 +103,8 @@ async def predict(request: dict):
97
  return {"video": video_b64, "content_type": "video/mp4"}
98
 
99
  except Exception as e:
 
 
100
  raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
101
 
102
  @app.get("/health")
 
64
  else:
65
  image_bytes = base64.b64decode(image_data)
66
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
67
+
68
+ # Resize to expected dimensions for Cosmos Video2World
69
+ image = image.resize((1280, 704))
70
+
71
  except Exception as e:
72
  raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}")
73
 
 
77
  guidance_scale = inputs.get("guidance_scale", 7.0)
78
  seed = inputs.get("seed")
79
 
80
+ # Create generator on correct device
81
  generator = None
82
  if seed is not None:
83
  generator = torch.Generator(device="cuda").manual_seed(int(seed))
84
 
85
  try:
86
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
87
+ output = pipe(
88
+ image=image,
89
+ prompt=prompt,
90
+ negative_prompt=negative_prompt,
91
+ num_frames=num_frames,
92
+ num_inference_steps=num_inference_steps,
93
+ guidance_scale=guidance_scale,
94
+ generator=generator,
95
+ )
96
 
97
  video_path = "/tmp/output.mp4"
98
  export_to_video(output.frames[0], video_path, fps=16)
 
103
  return {"video": video_b64, "content_type": "video/mp4"}
104
 
105
  except Exception as e:
106
+ import traceback
107
+ traceback.print_exc()
108
  raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
109
 
110
  @app.get("/health")