samwell committed on
Commit
d95102e
·
verified ·
1 Parent(s): 850ecac

Fix: ensure_on_device before each inference to prevent text_encoder drift

Browse files
Files changed (1) hide show
  1. handler.py +18 -10
handler.py CHANGED
@@ -6,6 +6,7 @@ from typing import Optional
6
  from PIL import Image
7
  from fastapi import FastAPI, HTTPException
8
  from pydantic import BaseModel
 
9
 
10
  app = FastAPI()
11
 
@@ -39,18 +40,19 @@ async def load_model():
39
  token=os.environ.get("HF_TOKEN"),
40
  )
41
  pipe = pipe.to(DEVICE)
42
-
43
- # Ensure all components are on the same device
44
- if hasattr(pipe, 'text_encoder') and pipe.text_encoder is not None:
45
- pipe.text_encoder = pipe.text_encoder.to(DEVICE)
46
- if hasattr(pipe, 'vae') and pipe.vae is not None:
47
- pipe.vae = pipe.vae.to(DEVICE)
48
- if hasattr(pipe, 'transformer') and pipe.transformer is not None:
49
- pipe.transformer = pipe.transformer.to(DEVICE)
50
-
51
  print("Model loaded successfully!")
52
  print(f"Pipeline device: {pipe.device}")
53
 
 
 
 
 
 
 
 
 
 
 
54
  @app.post("/predict")
55
  @app.post("/")
56
  async def predict(request: dict):
@@ -88,7 +90,9 @@ async def predict(request: dict):
88
  guidance_scale = inputs.get("guidance_scale", 7.0)
89
 
90
  try:
91
- # Run without generator - let pipeline handle device placement
 
 
92
  with torch.inference_mode():
93
  output = pipe(
94
  image=image,
@@ -105,6 +109,10 @@ async def predict(request: dict):
105
  with open(video_path, "rb") as f:
106
  video_b64 = base64.b64encode(f.read()).decode("utf-8")
107
 
 
 
 
 
108
  return {"video": video_b64, "content_type": "video/mp4"}
109
 
110
  except Exception as e:
 
6
  from PIL import Image
7
  from fastapi import FastAPI, HTTPException
8
  from pydantic import BaseModel
9
+ import gc
10
 
11
  app = FastAPI()
12
 
 
40
  token=os.environ.get("HF_TOKEN"),
41
  )
42
  pipe = pipe.to(DEVICE)
 
 
 
 
 
 
 
 
 
43
  print("Model loaded successfully!")
44
  print(f"Pipeline device: {pipe.device}")
45
 
46
+ def ensure_on_device():
47
+ """Ensure all pipeline components are on CUDA before inference"""
48
+ global pipe, DEVICE
49
+ pipe = pipe.to(DEVICE)
50
+ # Force text_encoder to CUDA (this is the problematic component)
51
+ if hasattr(pipe, 'text_encoder') and pipe.text_encoder is not None:
52
+ pipe.text_encoder = pipe.text_encoder.to(DEVICE)
53
+ torch.cuda.empty_cache()
54
+ gc.collect()
55
+
56
  @app.post("/predict")
57
  @app.post("/")
58
  async def predict(request: dict):
 
90
  guidance_scale = inputs.get("guidance_scale", 7.0)
91
 
92
  try:
93
+ # Ensure all components on CUDA before each inference
94
+ ensure_on_device()
95
+
96
  with torch.inference_mode():
97
  output = pipe(
98
  image=image,
 
109
  with open(video_path, "rb") as f:
110
  video_b64 = base64.b64encode(f.read()).decode("utf-8")
111
 
112
+ # Clean up after inference
113
+ torch.cuda.empty_cache()
114
+ gc.collect()
115
+
116
  return {"video": video_b64, "content_type": "video/mp4"}
117
 
118
  except Exception as e: