ArmanRV commited on
Commit
e337195
·
verified ·
1 Parent(s): 14da46e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -12
app.py CHANGED
@@ -4,11 +4,13 @@ import gradio as gr
4
  import torch
5
  from PIL import Image
6
 
 
7
  from diffusers import StableVideoDiffusionPipeline
8
  import imageio.v2 as imageio
9
 
 
10
 
11
- # -------- paths --------
12
  ROOT = "/data" if os.path.isdir("/data") else "/home/user"
13
  MODEL_DIR = os.path.join(ROOT, "models", "svd-xt")
14
  OUT_DIR = os.path.join(ROOT, "outputs")
@@ -16,19 +18,39 @@ os.makedirs(OUT_DIR, exist_ok=True)
16
 
17
  pipe = None
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def get_pipe():
20
  global pipe
21
  if pipe is not None:
22
  return pipe
23
 
24
- if not os.path.isdir(MODEL_DIR):
25
- raise gr.Error(f"Model not found at {MODEL_DIR}. postBuild didn't download it.")
26
 
27
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
28
  pipe = StableVideoDiffusionPipeline.from_pretrained(
29
  MODEL_DIR,
30
  torch_dtype=dtype,
31
- local_files_only=True, # <-- запрет докачки в рантайме
32
  )
33
 
34
  if torch.cuda.is_available():
@@ -51,13 +73,8 @@ def run(image: Image.Image, motion: int, fps: int, frames: int, steps: int, seed
51
  generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(int(seed))
52
 
53
  pipe = get_pipe()
54
-
55
- # SVD любит 1024 ширину по умолчанию, но лучше держать умеренно для VRAM
56
- # Можно подстроить под фото, но начнём с безопасного.
57
  img = image.convert("RGB")
58
 
59
- # В diffusers для SVD параметры могут называться немного по-разному между версиями,
60
- # но обычно работают: num_frames, fps, motion_bucket_id, num_inference_steps
61
  out = pipe(
62
  image=img,
63
  num_frames=int(frames),
@@ -73,15 +90,15 @@ def run(image: Image.Image, motion: int, fps: int, frames: int, steps: int, seed
73
  return out_path
74
 
75
 
76
- with gr.Blocks(title="SVD img2vid XT (local)") as demo:
77
- gr.Markdown("## Stable Video Diffusion (img2vid-xt) local in Space")
78
 
79
  with gr.Row():
80
  inp = gr.Image(type="pil", label="Input image")
81
  out = gr.Video(label="Output video")
82
 
83
  with gr.Accordion("Settings", open=False):
84
- motion = gr.Slider(1, 255, value=127, step=1, label="motion_bucket_id (higher = more motion)")
85
  fps = gr.Slider(6, 30, value=12, step=1, label="fps")
86
  frames = gr.Slider(8, 30, value=14, step=1, label="num_frames")
87
  steps = gr.Slider(10, 50, value=25, step=1, label="steps")
 
4
  import torch
5
  from PIL import Image
6
 
7
+ from huggingface_hub import snapshot_download
8
  from diffusers import StableVideoDiffusionPipeline
9
  import imageio.v2 as imageio
10
 
11
+ REPO_ID = "stabilityai/stable-video-diffusion-img2vid-xt"
12
 
13
+ # Где хранить файлы модели (пытаемся /data, если нет — /home/user)
14
  ROOT = "/data" if os.path.isdir("/data") else "/home/user"
15
  MODEL_DIR = os.path.join(ROOT, "models", "svd-xt")
16
  OUT_DIR = os.path.join(ROOT, "outputs")
 
18
 
19
  pipe = None
20
 
21
+
22
+ def ensure_model():
23
+ os.makedirs(MODEL_DIR, exist_ok=True)
24
+
25
+ # Если уже скачано — не качаем заново
26
+ if any(os.path.exists(os.path.join(MODEL_DIR, f)) for f in ["model_index.json", "config.json"]):
27
+ print("✅ Model already present in:", MODEL_DIR)
28
+ return
29
+
30
+ print("⬇️ Downloading model to:", MODEL_DIR)
31
+ snapshot_download(
32
+ repo_id=REPO_ID,
33
+ local_dir=MODEL_DIR,
34
+ local_dir_use_symlinks=False,
35
+ resume_download=True,
36
+ max_workers=4,
37
+ token=os.environ.get("HF_TOKEN"), # если модель gated
38
+ )
39
+ print("✅ Download finished. Top files:", os.listdir(MODEL_DIR)[:30])
40
+
41
+
42
  def get_pipe():
43
  global pipe
44
  if pipe is not None:
45
  return pipe
46
 
47
+ ensure_model()
 
48
 
49
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
50
  pipe = StableVideoDiffusionPipeline.from_pretrained(
51
  MODEL_DIR,
52
  torch_dtype=dtype,
53
+ local_files_only=True, # <-- важно: после скачивания не лезем в интернет
54
  )
55
 
56
  if torch.cuda.is_available():
 
73
  generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(int(seed))
74
 
75
  pipe = get_pipe()
 
 
 
76
  img = image.convert("RGB")
77
 
 
 
78
  out = pipe(
79
  image=img,
80
  num_frames=int(frames),
 
90
  return out_path
91
 
92
 
93
+ with gr.Blocks(title="SVD img2vid XT") as demo:
94
+ gr.Markdown("## Stable Video Diffusion (img2vid-xt)\nModel downloads once at startup (resume enabled).")
95
 
96
  with gr.Row():
97
  inp = gr.Image(type="pil", label="Input image")
98
  out = gr.Video(label="Output video")
99
 
100
  with gr.Accordion("Settings", open=False):
101
+ motion = gr.Slider(1, 255, value=100, step=1, label="motion_bucket_id (lower = calmer)")
102
  fps = gr.Slider(6, 30, value=12, step=1, label="fps")
103
  frames = gr.Slider(8, 30, value=14, step=1, label="num_frames")
104
  steps = gr.Slider(10, 50, value=25, step=1, label="steps")