baenacoco commited on
Commit
2bfc401
·
verified ·
1 Parent(s): ffa1e8a

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. README.md +6 -7
  2. app.py +323 -0
  3. hub_utils.py +64 -0
  4. packages.txt +2 -0
  5. requirements.txt +13 -0
README.md CHANGED
@@ -1,12 +1,11 @@
1
  ---
2
- title: Talking Head Lora Train
3
- emoji: 💻
4
- colorFrom: red
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 6.9.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Talking Head - LoRA Train
3
+ emoji: 🎨
4
+ colorFrom: yellow
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 5.9.1
8
  app_file: app.py
9
  pinned: false
10
+ suggested_hardware: a100-large
11
  ---
 
 
app.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Space 4: Train LoRA (Flux.1-dev + PEFT)

Downloads frames from Hub -> LoRA training on Flux.1 -> saves adapter to Hub.
GPU: A100 (Flux.1-dev full pipeline + LoRA)
"""
import gc
import json
import logging
import os
import shutil
import traceback
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from hub_utils import download_step, upload_step

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger(__name__)

# ── Config ──
# SPACE_ID is set by the Hugging Face Spaces runtime, so its presence marks a Space.
IS_HF_SPACE = os.environ.get("SPACE_ID") is not None
# Prefer the Space's persistent /data volume when it exists and is writable;
# otherwise fall back to a relative "data" directory (local development).
_data_path = Path("/data")
if IS_HF_SPACE and _data_path.exists() and os.access(_data_path, os.W_OK):
    BASE_DIR = _data_path
else:
    BASE_DIR = Path("data")

FRAMES_DIR = BASE_DIR / "frames"          # frames downloaded from the Hub (training input)
LORA_MODEL_DIR = BASE_DIR / "lora_model"  # trained adapter output
TEMP_DIR = BASE_DIR / "temp"              # scratch space (captioned dataset lives here)
HF_CACHE_DIR = BASE_DIR / "hf_cache"      # model download cache

for d in [FRAMES_DIR, LORA_MODEL_DIR, TEMP_DIR, HF_CACHE_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Redirect HF caches onto the chosen (possibly persistent) volume.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers releases;
# HF_HOME alone usually suffices — confirm against the pinned library versions.
os.environ["HF_HOME"] = str(HF_CACHE_DIR)
os.environ["TRANSFORMERS_CACHE"] = str(HF_CACHE_DIR)

# Training defaults (overridable from the UI where sliders/boxes exist).
FLUX_MODEL_ID = "black-forest-labs/FLUX.1-dev"
LORA_TRIGGER_WORD = "alvaro_person"   # token prepended to every caption
LORA_RANK = 16
LORA_ALPHA = 16
LORA_LR = 1e-4
LORA_STEPS = 1500
LORA_BATCH_SIZE = 1
LORA_RESOLUTION = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

APP_VERSION = "1.0.0"
55
+
56
+ def _clear_cache():
57
+ gc.collect()
58
+ if torch.cuda.is_available():
59
+ torch.cuda.empty_cache()
60
+ torch.cuda.synchronize()
61
+
62
+
63
+ # ── Dataset preparation ──
64
+
65
def _prepare_dataset(image_dir, trigger_word):
    """Build a captioned LoRA dataset under TEMP_DIR from the frames in *image_dir*.

    Copies every .jpg/.png frame into a fresh TEMP_DIR/"lora_dataset" folder and
    writes a sidecar .txt caption ("<trigger_word>, <caption>") next to each
    image. Captions come from image_dir/captions.json when present, with a
    generic fallback otherwise. Returns the dataset directory path.
    """
    target = TEMP_DIR / "lora_dataset"
    # Start from a clean slate so stale files never leak into training.
    if target.exists():
        shutil.rmtree(target)
    target.mkdir(parents=True)

    frame_paths = sorted(image_dir.glob("*.jpg")) + sorted(image_dir.glob("*.png"))

    caption_map = {}
    captions_json = image_dir / "captions.json"
    if captions_json.exists():
        with open(captions_json) as fh:
            caption_map = json.load(fh)

    for idx, src in enumerate(frame_paths):
        dst = target / f"img_{idx:04d}{src.suffix}"
        shutil.copy2(src, dst)
        text = f"{trigger_word}, {caption_map.get(src.name, 'a photo of a person')}"
        dst.with_suffix(".txt").write_text(text)

    logger.info(f"Prepared {len(frame_paths)} images with trigger word '{trigger_word}'")
    return target
87
+
88
+
89
+ # ── LoRA training ──
90
+
91
def _train_lora(
    dataset_dir, output_dir, rank, alpha, learning_rate,
    max_steps, batch_size, resolution, progress_callback=None,
):
    """Fine-tune Flux.1-dev with a LoRA adapter on a captioned image folder.

    Args:
        dataset_dir: folder of images plus sidecar .txt captions
            (as produced by ``_prepare_dataset``).
        output_dir: where the trained PEFT adapter is saved.
        rank / alpha: LoRA rank and scaling applied to the transformer's
            attention projections.
        learning_rate: AdamW learning rate for the LoRA parameters.
        max_steps: total optimizer steps to run.
        batch_size: DataLoader batch size.
        resolution: square resize applied to every training image.
        progress_callback: optional ``fn(fraction, message)`` invoked every
            50 steps; fraction is mapped into [0.1, 0.95].

    Side effects: downloads the base model (needs HF_TOKEN for the gated
    repo), saves the adapter to *output_dir*, then frees the pipeline.
    """
    # Heavy imports are deferred so the module loads fast in the UI process.
    from diffusers import FluxPipeline
    from peft import LoraConfig, get_peft_model
    from torch.utils.data import Dataset, DataLoader

    class CaptionedImageDataset(Dataset):
        # Pairs each image with the caption from its sidecar .txt file.
        def __init__(self, root_dir, res):
            self.root = Path(root_dir)
            self.images = sorted(self.root.glob("*.jpg")) + sorted(self.root.glob("*.png"))
            self.transform = transforms.Compose([
                transforms.Resize((res, res)),
                transforms.ToTensor(),
                # Map pixels from [0, 1] to [-1, 1], the range the VAE expects.
                transforms.Normalize([0.5], [0.5]),
            ])
        def __len__(self):
            return len(self.images)
        def __getitem__(self, idx):
            img_path = self.images[idx]
            image = Image.open(img_path).convert("RGB")
            image = self.transform(image)
            txt_path = img_path.with_suffix(".txt")
            caption = txt_path.read_text().strip() if txt_path.exists() else ""
            return {"image": image, "caption": caption}

    logger.info(f"Loading Flux.1 from {FLUX_MODEL_ID}...")
    pipe = FluxPipeline.from_pretrained(
        FLUX_MODEL_ID, torch_dtype=torch.bfloat16,
        token=os.environ.get("HF_TOKEN"),
    )

    # Inject LoRA only into the attention q/k/v/out projections of the transformer.
    lora_config = LoraConfig(
        r=rank, lora_alpha=alpha,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        lora_dropout=0.0,
    )

    pipe.transformer = get_peft_model(pipe.transformer, lora_config)
    pipe.transformer.to(DEVICE, dtype=torch.bfloat16)
    pipe.transformer.train()

    trainable_params = sum(p.numel() for p in pipe.transformer.parameters() if p.requires_grad)
    logger.info(f"Trainable LoRA parameters: {trainable_params:,}")

    dataset = CaptionedImageDataset(dataset_dir, resolution)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Only the (requires_grad) LoRA parameters are optimized; the base model is frozen.
    optimizer = torch.optim.AdamW(
        [p for p in pipe.transformer.parameters() if p.requires_grad],
        lr=learning_rate, weight_decay=0.01,
    )

    # Encoders/VAE are used for feature extraction only (no grads needed below).
    pipe.text_encoder.to(DEVICE, dtype=torch.bfloat16)
    if hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 is not None:
        pipe.text_encoder_2.to(DEVICE, dtype=torch.bfloat16)
    pipe.vae.to(DEVICE, dtype=torch.bfloat16)

    global_step = 0
    # Enough epochs to cover max_steps; the inner guard stops exactly at max_steps.
    for epoch in range(max_steps // max(1, len(dataset)) + 1):
        for batch in loader:
            if global_step >= max_steps:
                break
            images_batch = batch["image"].to(DEVICE, dtype=torch.bfloat16)
            captions_batch = batch["caption"]

            with torch.no_grad():
                # VAE-encode images, then apply Flux's latent shift/scale normalization.
                latents = pipe.vae.encode(images_batch).latent_dist.sample()
                latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
                prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
                    prompt=captions_batch, prompt_2=captions_batch,
                )
                prompt_embeds = prompt_embeds.to(dtype=torch.bfloat16)
                pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=torch.bfloat16)

            # Rectified-flow style noising: linear interpolation between latents and noise.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, 1000, (latents.shape[0],), device=latents.device)
            sigmas = (timesteps.float() / 1000.0).to(dtype=torch.bfloat16).view(-1, 1, 1, 1)
            noisy_latents = (1 - sigmas) * latents + sigmas * noise

            # Flux consumes "packed" (patchified) latents plus positional id tensors.
            bs, ch, h, w = noisy_latents.shape
            noisy_packed = pipe._pack_latents(noisy_latents, bs, ch, h, w)
            latent_image_ids = pipe._prepare_latent_image_ids(bs, h // 2, w // 2, DEVICE, torch.bfloat16)
            guidance = torch.full((bs,), 3.5, device=DEVICE, dtype=torch.bfloat16)

            # NOTE(review): raw integer timesteps (0..999) are passed here; the
            # diffusers Flux transformer is commonly fed timesteps scaled to
            # [0, 1] (t/1000) — confirm against the pinned diffusers version.
            noise_pred = pipe.transformer(
                hidden_states=noisy_packed, timestep=timesteps, guidance=guidance,
                encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds,
                txt_ids=text_ids, img_ids=latent_image_ids, return_dict=False,
            )[0]

            # Flow-matching target: velocity (noise - clean latents), packed the same way.
            target = noise - latents
            target_packed = pipe._pack_latents(target, bs, ch, h, w)
            loss = torch.nn.functional.mse_loss(noise_pred, target_packed)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            global_step += 1
            if global_step % 50 == 0:
                logger.info(f"Step {global_step}/{max_steps}, Loss: {loss.item():.4f}")
            if progress_callback:
                # Map training progress into the [0.1, 0.95] slice of the UI bar.
                prog = 0.1 + (global_step / max_steps) * 0.85
                progress_callback(prog, f"Step {global_step}/{max_steps}, Loss: {loss.item():.4f}")

        if global_step >= max_steps:
            break

    # save_pretrained on the PEFT-wrapped transformer writes only the adapter weights.
    pipe.transformer.save_pretrained(str(output_dir))
    logger.info(f"LoRA saved to {output_dir}")
    del pipe
    _clear_cache()
205
+
206
+
207
+ # ── Gradio handlers ──
208
+
209
def download_frames_from_hub(project_name):
    """Fetch the step1 frames for *project_name* from the Hub into FRAMES_DIR.

    Wipes any previously downloaded frames, then flattens the snapshot layout
    ({project}/step1_frames/...) into FRAMES_DIR. Returns a human-readable
    status string (Spanish, shown in the UI); errors are reported, not raised.
    """
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"
    name = project_name.strip()
    try:
        # Reset the local frames folder before downloading.
        if FRAMES_DIR.exists():
            shutil.rmtree(FRAMES_DIR)
        FRAMES_DIR.mkdir(parents=True)

        download_step(name, "step1_frames", str(BASE_DIR))

        snapshot_dir = BASE_DIR / name / "step1_frames"
        if snapshot_dir.exists():
            # Flatten: move every downloaded file into FRAMES_DIR, then
            # drop the now-empty snapshot tree.
            for entry in snapshot_dir.iterdir():
                shutil.move(str(entry), str(FRAMES_DIR / entry.name))
            shutil.rmtree(BASE_DIR / name, ignore_errors=True)

        count = len(sorted(FRAMES_DIR.glob("*.jpg")))
        return f"OK - Descargados {count} frames"
    except Exception as e:
        return f"Error: {e}"
229
+
230
+
231
def train_lora_handler(project_name, trigger_word, rank, lr, steps, progress=gr.Progress()):
    """Gradio handler: run the full LoRA training pipeline on downloaded frames.

    Validates inputs (project name set, >= 10 frames present), prepares the
    captioned dataset, trains the adapter, writes lora_config.json next to it,
    and cleans up scratch data. Returns a status string for the UI; every
    failure is caught and reported as "Error: ...".
    """
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"

    frame_files = list(FRAMES_DIR.glob("*.jpg")) + list(FRAMES_DIR.glob("*.png"))
    if len(frame_files) < 10:
        return f"Error: Se necesitan al menos 10 imagenes, encontradas {len(frame_files)}. Descarga frames primero."

    logger.info(f"=== LoRA Training Started === trigger={trigger_word}, rank={rank}, steps={steps}")
    try:
        _clear_cache()
        progress(0.05, desc="Preparando dataset...")
        ds_dir = _prepare_dataset(FRAMES_DIR, trigger_word)

        progress(0.1, desc="Iniciando entrenamiento LoRA...")
        _train_lora(
            dataset_dir=ds_dir,
            output_dir=LORA_MODEL_DIR,
            rank=int(rank),
            alpha=int(rank),
            learning_rate=lr,
            max_steps=int(steps),
            batch_size=LORA_BATCH_SIZE,
            resolution=LORA_RESOLUTION,
            progress_callback=lambda p, m: progress(p, desc=m),
        )

        # Persist the hyperparameters next to the adapter for reproducibility.
        metadata = {
            "base_model": FLUX_MODEL_ID,
            "trigger_word": trigger_word,
            "rank": int(rank),
            "alpha": int(rank),
            "steps": int(steps),
            "resolution": LORA_RESOLUTION,
        }
        with open(LORA_MODEL_DIR / "lora_config.json", "w") as fh:
            json.dump(metadata, fh, indent=2)

        shutil.rmtree(ds_dir, ignore_errors=True)
        _clear_cache()

        logger.info("=== LoRA Training Complete ===")
        return f"OK - LoRA guardado en: {LORA_MODEL_DIR}"
    except Exception as e:
        logger.error(f"=== LoRA Training Failed ===\n{traceback.format_exc()}")
        return f"Error: {e}"
270
+
271
+
272
def save_to_hub(project_name):
    """Upload the trained LoRA adapter folder to the Hub under {project}/step4_lora.

    Returns a status string for the UI; refuses to upload when no adapter
    weights are present in LORA_MODEL_DIR.
    """
    if not project_name or not project_name.strip():
        return "Error: Debes introducir un nombre de proyecto"
    name = project_name.strip()

    # Only upload if trained weights actually exist on disk.
    weight_files = list(LORA_MODEL_DIR.glob("*.safetensors")) + list(LORA_MODEL_DIR.glob("adapter_model.*"))
    if not weight_files:
        return "Error: No hay modelo LoRA para guardar. Entrena primero."

    try:
        return upload_step(name, "step4_lora", str(LORA_MODEL_DIR))
    except Exception as e:
        return f"Error: {e}"
283
+
284
+
285
# ── UI ──

with gr.Blocks(title="Talking Head - LoRA Train", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"# Talking Head - Entrenar LoRA `v{APP_VERSION}`\nFlux.1-dev LoRA training con tus imagenes")

    # Shared project identifier: all Hub paths are namespaced under this name.
    project_name = gr.Textbox(
        label="Nombre del proyecto",
        placeholder="mi_proyecto",
        info="Obligatorio. Se usa como carpeta en el Hub.",
    )

    # Step 1: pull previously extracted frames from the dataset repo.
    gr.Markdown("### 1. Descargar frames del Hub")
    download_btn = gr.Button("Descargar frames del Hub", variant="secondary")
    download_status = gr.Textbox(label="Estado descarga", interactive=False)

    # Step 2: training hyperparameters and the launch button.
    gr.Markdown("### 2. Entrenar LoRA")
    with gr.Row():
        trigger_word = gr.Textbox(value=LORA_TRIGGER_WORD, label="Trigger Word")
        lora_rank = gr.Slider(4, 64, value=LORA_RANK, step=4, label="LoRA Rank")
    with gr.Row():
        lora_lr = gr.Number(value=LORA_LR, label="Learning Rate")
        lora_steps = gr.Slider(500, 5000, value=LORA_STEPS, step=100, label="Training Steps")
    train_btn = gr.Button("Entrenar LoRA", variant="primary")
    train_status = gr.Textbox(label="Estado entrenamiento", interactive=False)

    # Step 3: push the trained adapter back to the Hub.
    gr.Markdown("### 3. Guardar modelo en Hub")
    save_btn = gr.Button("Guardar en Hub", variant="secondary")
    save_status = gr.Textbox(label="Estado guardado", interactive=False)

    # Wire buttons to handlers; each handler returns a status string shown
    # in its corresponding textbox.
    download_btn.click(download_frames_from_hub, inputs=[project_name], outputs=[download_status])
    train_btn.click(
        train_lora_handler,
        inputs=[project_name, trigger_word, lora_rank, lora_lr, lora_steps],
        outputs=[train_status],
    )
    save_btn.click(save_to_hub, inputs=[project_name], outputs=[save_status])

if __name__ == "__main__":
    # queue() enables request queuing, needed for long-running training jobs.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
hub_utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub utilities for uploading/downloading step data to HF Dataset repo."""
2
+ import os
3
+ import logging
4
+ from pathlib import Path
5
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_tree
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ HF_DATASET_REPO_ID = "baenacoco/talking-head-avatar"
10
+
11
+
12
def _get_api():
    """Return an authenticated HfApi, ensuring the dataset repo exists.

    Raises ValueError when HF_TOKEN is not configured in the environment.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError("HF_TOKEN no encontrado en variables de entorno")
    client = HfApi(token=token)
    # Idempotent: exist_ok makes this a no-op after the first call.
    client.create_repo(repo_id=HF_DATASET_REPO_ID, repo_type="dataset", exist_ok=True)
    return client
19
+
20
+
21
def upload_step(name: str, step_folder: str, local_dir: str):
    """Upload *local_dir* to {name}/{step_folder}/ in the dataset repo.

    Returns a Spanish status string for display in the UI.
    """
    _get_api().upload_folder(
        folder_path=local_dir,
        repo_id=HF_DATASET_REPO_ID,
        repo_type="dataset",
        path_in_repo=f"{name}/{step_folder}",
    )
    logger.info(f"Uploaded {local_dir} -> {name}/{step_folder}")
    return f"Subido a Hub: {name}/{step_folder}"
32
+
33
+
34
def download_step(name: str, step_folder: str, local_dir: str):
    """Download {name}/{step_folder}/ from the dataset repo into *local_dir*."""
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id=HF_DATASET_REPO_ID,
        repo_type="dataset",
        local_dir=local_dir,
        # Restrict the snapshot to just this project/step subtree.
        allow_patterns=[f"{name}/{step_folder}/**"],
        token=os.environ.get("HF_TOKEN"),
    )
    logger.info(f"Downloaded {name}/{step_folder} -> {local_dir}")
    return f"Descargado de Hub: {name}/{step_folder}"
47
+
48
+
49
def list_projects() -> list[str]:
    """List project names (top-level folders) in the dataset repo.

    Returns a sorted list of unique top-level directory names. Top-level
    *files* (e.g. README.md, .gitattributes) are excluded. The previous
    filter `("/" in getattr(e, "rfilename", "")) or hasattr(e, "path")` was
    a no-op: `list_repo_tree` entries expose `.path` (not `.rfilename`), so
    every entry — including plain files — was reported as a project.

    Returns [] on any Hub error (missing token, network failure, absent repo).
    """
    token = os.environ.get("HF_TOKEN")
    try:
        api = HfApi(token=token)
        # list_repo_files returns the flat, recursive list of file paths;
        # any path containing "/" lives inside a top-level folder.
        files = api.list_repo_files(repo_id=HF_DATASET_REPO_ID, repo_type="dataset")
        return sorted({f.split("/", 1)[0] for f in files if "/" in f})
    except Exception as e:
        logger.warning(f"Could not list projects: {e}")
        return []
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ libgl1
2
+ libglib2.0-0
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ setuptools>=69.0.0
2
+ gradio>=5.9.1
3
+ torch>=2.1.0
4
+ torchvision>=0.16.0
5
+ transformers>=4.36.0,<5.0.0
6
+ diffusers>=0.25.0
7
+ accelerate>=0.25.0
8
+ safetensors>=0.4.0
9
+ peft>=0.7.0
10
+ huggingface_hub>=0.20.0
11
+ Pillow>=10.0.0
12
+ sentencepiece>=0.1.99
13
+ protobuf>=3.20.0