Gjm1234 commited on
Commit
5bda45d
·
verified ·
1 Parent(s): a964870

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -11
app.py CHANGED
@@ -10,21 +10,18 @@ import imageio
10
  import torch
11
  from custom_wan_pipeline import WanImageToVideoPipeline
12
 
 
13
  app = FastAPI()
14
 
15
- # ===== LOAD YOUR MODEL =====
16
  print("Loading WAN I2V model...")
17
  pipe = WanImageToVideoPipeline.from_pretrained(
18
- ".", # model files live in the repo root
19
  torch_dtype=torch.float16
20
  ).to("cuda")
21
-
22
  pipe.enable_xformers_memory_efficient_attention()
23
  print("Model loaded.")
24
 
25
 
26
- # ===== Helpers =====
27
-
28
  def decode_image(b64_string):
29
  image_bytes = base64.b64decode(b64_string)
30
  return Image.open(io.BytesIO(image_bytes)).convert("RGB")
@@ -42,32 +39,28 @@ def frames_to_base64_mp4(frames):
42
  return base64.b64encode(video_bytes).decode()
43
 
44
 
45
- # ===== ROUTES =====
46
-
47
  @app.post("/video")
48
  async def video_route(body: dict):
49
  try:
50
  image_b64 = body["image"]
51
  prompt = body.get("prompt", "")
52
 
53
- # Decode image
54
  image = decode_image(image_b64)
55
 
56
- # Run WAN model
57
  output = pipe(image=image, prompt=prompt)
58
  frames = output.frames
59
 
60
- # Convert frames → mp4 → base64
61
  video_b64 = frames_to_base64_mp4(frames)
62
 
63
  return JSONResponse({"video": video_b64})
 
64
  except Exception as e:
65
  return JSONResponse({"error": str(e)}, status_code=500)
66
 
67
 
68
- # HF Spaces launch
69
  def start():
70
  uvicorn.run(app, host="0.0.0.0", port=7860)
71
 
 
72
  if __name__ == "__main__":
73
  start()
 
10
  import torch
11
  from custom_wan_pipeline import WanImageToVideoPipeline
12
 
13
+
14
  app = FastAPI()
15
 
 
16
  print("Loading WAN I2V model...")
17
  pipe = WanImageToVideoPipeline.from_pretrained(
18
+ ".",
19
  torch_dtype=torch.float16
20
  ).to("cuda")
 
21
  pipe.enable_xformers_memory_efficient_attention()
22
  print("Model loaded.")
23
 
24
 
 
 
25
  def decode_image(b64_string):
26
  image_bytes = base64.b64decode(b64_string)
27
  return Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
39
  return base64.b64encode(video_bytes).decode()
40
 
41
 
 
 
42
  @app.post("/video")
43
  async def video_route(body: dict):
44
  try:
45
  image_b64 = body["image"]
46
  prompt = body.get("prompt", "")
47
 
 
48
  image = decode_image(image_b64)
49
 
 
50
  output = pipe(image=image, prompt=prompt)
51
  frames = output.frames
52
 
 
53
  video_b64 = frames_to_base64_mp4(frames)
54
 
55
  return JSONResponse({"video": video_b64})
56
+
57
  except Exception as e:
58
  return JSONResponse({"error": str(e)}, status_code=500)
59
 
60
 
 
61
  def start():
62
  uvicorn.run(app, host="0.0.0.0", port=7860)
63
 
64
+
65
  if __name__ == "__main__":
66
  start()