samwell commited on
Commit
0c995b9
·
verified ·
1 Parent(s): 0e4048f

Remove generator entirely to fix CUDA/CPU device mismatch

Browse files
Files changed (1) hide show
  1. handler.py +6 -8
handler.py CHANGED
@@ -56,17 +56,19 @@ async def predict(request: dict):
56
  if not prompt:
57
  raise HTTPException(status_code=400, detail="No prompt provided")
58
 
59
- # Load image
60
  try:
 
 
61
  if image_data.startswith("http"):
62
- from diffusers.utils import load_image
63
  image = load_image(image_data)
64
  else:
 
65
  image_bytes = base64.b64decode(image_data)
66
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
67
 
68
  # Resize to expected dimensions for Cosmos Video2World (720P model)
69
- image = image.resize((1280, 704))
70
 
71
  except Exception as e:
72
  raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}")
@@ -75,12 +77,9 @@ async def predict(request: dict):
75
  num_frames = inputs.get("num_frames", 93)
76
  num_inference_steps = inputs.get("num_inference_steps", 35)
77
  guidance_scale = inputs.get("guidance_scale", 7.0)
78
- seed = inputs.get("seed", 42)
79
-
80
- # Generator WITHOUT device specification (let diffusers handle it)
81
- generator = torch.Generator().manual_seed(int(seed))
82
 
83
  try:
 
84
  output = pipe(
85
  image=image,
86
  prompt=prompt,
@@ -88,7 +87,6 @@ async def predict(request: dict):
88
  num_frames=num_frames,
89
  num_inference_steps=num_inference_steps,
90
  guidance_scale=guidance_scale,
91
- generator=generator,
92
  )
93
 
94
  video_path = "/tmp/output.mp4"
 
56
  if not prompt:
57
  raise HTTPException(status_code=400, detail="No prompt provided")
58
 
59
+ # Load image using diffusers' load_image for consistent preprocessing
60
  try:
61
+ from diffusers.utils import load_image
62
+
63
  if image_data.startswith("http"):
 
64
  image = load_image(image_data)
65
  else:
66
+ # Save base64 to temp file and load with load_image for consistent handling
67
  image_bytes = base64.b64decode(image_data)
68
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
69
 
70
  # Resize to expected dimensions for Cosmos Video2World (720P model)
71
+ image = image.resize((1280, 704), Image.Resampling.LANCZOS)
72
 
73
  except Exception as e:
74
  raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}")
 
77
  num_frames = inputs.get("num_frames", 93)
78
  num_inference_steps = inputs.get("num_inference_steps", 35)
79
  guidance_scale = inputs.get("guidance_scale", 7.0)
 
 
 
 
80
 
81
  try:
82
+ # Run inference WITHOUT generator to avoid device mismatch
83
  output = pipe(
84
  image=image,
85
  prompt=prompt,
 
87
  num_frames=num_frames,
88
  num_inference_steps=num_inference_steps,
89
  guidance_scale=guidance_scale,
 
90
  )
91
 
92
  video_path = "/tmp/output.mp4"