eeshaAI committed on
Commit
6bdc1b2
·
verified ·
1 Parent(s): c8810e1

Update train_on_hf_spaces.py: auto-start training, file-based logging

Browse files
Files changed (1) hide show
  1. train_on_hf_spaces.py +198 -182
train_on_hf_spaces.py CHANGED
@@ -5,13 +5,8 @@ HuggingFace Spaces Training Script for EeshaAI/zeeb
5
  Runs on HuggingFace Spaces (free CPU tier, 16GB RAM).
6
  Fine-tunes OLMo 2 1B Instruct with LoRA to generate video tokens.
7
 
8
- Steps:
9
- 1. Load OLMo 2 1B Instruct (full HuggingFace model, fp32)
10
- 2. Expand vocabulary with visual tokens (<v_0> ... <v_1023>)
11
- 3. Apply LoRA (r=4, alpha=8) to q_proj and v_proj
12
- 4. Train on tokenized video data (3 epochs)
13
- 5. Merge LoRA weights back into base model
14
- 6. Push merged model to EeshaAI/zeeb
15
  """
16
 
17
  import os
@@ -20,7 +15,7 @@ import json
20
  import time
21
  import traceback
22
  import gc
23
- from typing import Generator
24
 
25
  import torch
26
  from torch.utils.data import DataLoader, Dataset
@@ -46,6 +41,25 @@ GRADIENT_ACCUMULATION = 4
46
  MAX_GRAD_NORM = 1.0
47
  LOG_EVERY = 1
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # ---------------------------------------------------------------------------
50
  # Dataset
51
  # ---------------------------------------------------------------------------
@@ -56,7 +70,6 @@ class VideoTokenDataset(Dataset):
56
  with open(data_path) as f:
57
  self.data = json.load(f)
58
  self.max_tokens = max_tokens
59
- print(f"[Dataset] Loaded {len(self.data)} samples from {data_path}")
60
 
61
  def __len__(self):
62
  return len(self.data)
@@ -74,92 +87,79 @@ class VideoTokenDataset(Dataset):
74
 
75
 
76
  # ---------------------------------------------------------------------------
77
- # Training
78
  # ---------------------------------------------------------------------------
79
- def train(data_path: str = "tokenized_dataset.json") -> Generator[str, None, None]:
80
- """
81
- Main training loop. Yields log messages for the Gradio UI.
82
- """
83
- yield "πŸš€ Starting training pipeline...\n"
84
-
85
- # ── 1. Load tokenizer & model ──────────────────────────────────────────
86
- yield "πŸ“¦ Loading OLMo 2 1B Instruct tokenizer...\n"
87
 
88
  try:
89
- from transformers import AutoModelForCausalLM, AutoTokenizer
90
- except ImportError as e:
91
- yield f"❌ transformers not installed: {e}\n"
92
- raise
93
-
94
- try:
95
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
96
- if tokenizer.pad_token is None:
97
- tokenizer.pad_token = tokenizer.eos_token
98
- yield f"βœ… Tokenizer loaded. Vocab size: {len(tokenizer)}\n"
99
  except Exception as e:
100
- yield f"❌ Failed to load tokenizer: {e}\n"
101
- yield traceback.format_exc() + "\n"
102
- raise
103
 
104
- yield "πŸ“¦ Loading model in float32 on CPU (this takes ~2-3 min)...\n"
105
- try:
106
- model = AutoModelForCausalLM.from_pretrained(
107
- MODEL_NAME,
108
- trust_remote_code=True,
109
- torch_dtype=torch.float32,
110
- )
111
- yield f"βœ… Model loaded. Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M\n"
112
- except Exception as e:
113
- yield f"❌ Failed to load model: {e}\n"
114
- yield traceback.format_exc() + "\n"
115
- raise
116
 
117
- # ── 2. Expand vocabulary ───────────────────────────────────────────────
118
- yield f"πŸ”€ Adding {CODEBOOK_SIZE} visual tokens + special tokens...\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  visual_tokens = [VIDEO_START, VIDEO_END, VIDEO_PAD]
120
  for i in range(CODEBOOK_SIZE):
121
  visual_tokens.append(f"<v_{i}>")
122
 
123
  num_added = tokenizer.add_tokens(visual_tokens)
124
  model.resize_token_embeddings(len(tokenizer))
125
- yield f"βœ… Added {num_added} tokens. New vocab size: {len(tokenizer)}\n"
126
-
127
- # ── 3. Apply LoRA ─────────────────────────────────────────────────────
128
- yield f"πŸ”§ Applying LoRA (r={LORA_R}, alpha={LORA_ALPHA})...\n"
129
- try:
130
- from peft import LoraConfig, get_peft_model, TaskType
131
-
132
- lora_config = LoraConfig(
133
- r=LORA_R,
134
- lora_alpha=LORA_ALPHA,
135
- target_modules=["q_proj", "v_proj"],
136
- lora_dropout=LORA_DROPOUT,
137
- bias="none",
138
- task_type=TaskType.CAUSAL_LM,
139
- )
140
- model = get_peft_model(model, lora_config)
141
- trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
142
- total = sum(p.numel() for p in model.parameters())
143
- yield f"βœ… LoRA applied. Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)\n"
144
- except Exception as e:
145
- yield f"❌ Failed to apply LoRA: {e}\n"
146
- yield traceback.format_exc() + "\n"
147
- raise
148
-
149
- # ── 4. Load dataset ───────────────────────────────────────────────────
150
- yield f"πŸ“Š Loading dataset from {data_path}...\n"
151
- try:
152
- dataset = VideoTokenDataset(data_path, max_tokens=256)
153
- dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
154
- total_steps = NUM_EPOCHS * len(dataloader)
155
- yield f"πŸ“Š {len(dataset)} samples Γ— {NUM_EPOCHS} epochs = {total_steps} steps\n"
156
- except Exception as e:
157
- yield f"❌ Failed to load dataset: {e}\n"
158
- yield traceback.format_exc() + "\n"
159
- raise
160
-
161
- # ── 5. Train ──────────────────────────────────────────────────────────
162
- yield "πŸ”₯ Starting training loop...\n\n"
163
 
164
  optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
165
  model.train()
@@ -168,123 +168,139 @@ def train(data_path: str = "tokenized_dataset.json") -> Generator[str, None, Non
168
  running_loss = 0.0
169
  start_time = time.time()
170
 
171
- try:
172
- for epoch in range(NUM_EPOCHS):
173
- epoch_loss = 0.0
174
- num_batches = 0
175
-
176
- for batch_idx, batch in enumerate(dataloader):
177
- prompt = batch["prompt"][0]
178
- video_tokens = batch["video_tokens"][0]
179
-
180
- # Format training text
181
- token_str = " ".join(f"<v_{t.item()}>" for t in video_tokens[:64]) # limit tokens for memory
182
- text = f"Create a video of: {prompt} {VIDEO_START} {token_str} {VIDEO_END}"
183
-
184
- inputs = tokenizer(
185
- text,
186
- return_tensors="pt",
187
- truncation=True,
188
- max_length=MAX_SEQ_LEN,
189
- padding="max_length",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  )
191
 
192
- # Forward pass
193
- outputs = model(**inputs, labels=inputs["input_ids"])
194
- loss = outputs.loss / GRADIENT_ACCUMULATION
195
-
196
- # Backward pass
197
- loss.backward()
198
-
199
- if (batch_idx + 1) % GRADIENT_ACCUMULATION == 0 or (batch_idx + 1) == len(dataloader):
200
- torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
201
- optimizer.step()
202
- optimizer.zero_grad()
203
 
204
- global_step += 1
205
- batch_loss = loss.item() * GRADIENT_ACCUMULATION
206
- epoch_loss += batch_loss
207
- running_loss += batch_loss
208
- num_batches += 1
209
 
210
- elapsed = time.time() - start_time
211
- steps_per_sec = global_step / elapsed if elapsed > 0 else 0
 
212
 
213
- if batch_idx % LOG_EVERY == 0:
214
- msg = (
215
- f" Epoch {epoch + 1}/{NUM_EPOCHS} | "
216
- f"Step {batch_idx + 1}/{len(dataloader)} | "
217
- f"Loss: {batch_loss:.4f} | "
218
- f"Avg: {epoch_loss / num_batches:.4f} | "
219
- f"Speed: {steps_per_sec:.2f} steps/s\n"
220
- )
221
- yield msg
222
 
223
- # Free memory
224
- del outputs, loss
225
- gc.collect()
 
 
226
 
227
- avg_epoch_loss = epoch_loss / num_batches
228
- yield f"\nπŸ“ˆ Epoch {epoch + 1} complete. Avg Loss: {avg_epoch_loss:.4f}\n\n"
229
 
230
- except Exception as e:
231
- yield f"\n❌ Training error: {e}\n"
232
- yield traceback.format_exc() + "\n"
233
- raise
234
-
235
- total_time = time.time() - start_time
236
- yield f"βœ… Training complete in {total_time:.0f}s ({total_time / 60:.1f} min)\n"
237
- yield f" Final avg loss: {running_loss / global_step:.4f}\n\n"
238
 
239
- # ── 6. Merge & push ──────────────────────────────────────────────────
240
- yield "πŸ”€ Merging LoRA weights back into base model...\n"
241
  try:
242
- model = model.merge_and_unload()
243
- yield "βœ… LoRA merged.\n"
244
  except Exception as e:
245
- yield f"⚠️ Merge note: {e}\n"
246
 
247
- yield "πŸ’Ύ Saving model locally...\n"
248
- save_dir = "./trained_model"
249
- try:
250
- model.save_pretrained(save_dir, safe_serialization=True)
251
- tokenizer.save_pretrained(save_dir)
252
- yield "βœ… Model saved locally.\n"
253
- except Exception as e:
254
- yield f"❌ Save failed: {e}\n"
255
- yield traceback.format_exc() + "\n"
256
- raise
257
-
258
- yield f"πŸš€ Pushing to {REPO_ID}...\n"
259
- try:
260
- from huggingface_hub import HfApi
261
 
262
- api = HfApi(token=HF_TOKEN)
263
 
264
- # Create model repo if it doesn't exist
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  try:
266
- api.create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
267
- except Exception as e:
268
- yield f"⚠️ Repo creation note: {e}\n"
269
-
270
- api.upload_folder(
271
- folder_path=save_dir,
272
- repo_id=REPO_ID,
273
- repo_type="model",
274
- commit_message=f"LoRA-trained OLMo 2 1B (r={LORA_R}, {NUM_EPOCHS} epochs)",
275
- )
276
- yield f"βœ… Model pushed to https://huggingface.co/{REPO_ID}\n"
277
- yield "\nπŸŽ‰ All done! The trained model is now available on HuggingFace.\n"
278
- except Exception as e:
279
- yield f"❌ Push failed: {e}\n"
280
- yield traceback.format_exc() + "\n"
281
- raise
 
 
 
282
 
283
 
284
  # ---------------------------------------------------------------------------
285
- # CLI entry point (for local testing)
286
  # ---------------------------------------------------------------------------
287
  if __name__ == "__main__":
288
  data_path = sys.argv[1] if len(sys.argv) > 1 else "tokenized_dataset.json"
289
- for log_msg in train(data_path):
290
- print(log_msg, end="", flush=True)
 
5
  Runs on HuggingFace Spaces (free CPU tier, 16GB RAM).
6
  Fine-tunes OLMo 2 1B Instruct with LoRA to generate video tokens.
7
 
8
+ Writes all logs to a file for the Gradio UI to read.
9
+ Auto-pushes the trained model to EeshaAI/zeeb when done.
 
 
 
 
 
10
  """
11
 
12
  import os
 
15
  import time
16
  import traceback
17
  import gc
18
+ import threading
19
 
20
  import torch
21
  from torch.utils.data import DataLoader, Dataset
 
41
  MAX_GRAD_NORM = 1.0
42
  LOG_EVERY = 1
43
 
44
+
45
+ class _Logger:
46
+ """Thread-safe logger that writes to both stdout and a log file."""
47
+ def __init__(self, log_path):
48
+ self.log_path = log_path
49
+ self.lock = threading.Lock()
50
+ # Initialize log file
51
+ with open(log_path, "w") as f:
52
+ f.write("πŸš€ Zeeb Training Pipeline Starting...\n\n")
53
+
54
+ def log(self, msg):
55
+ with self.lock:
56
+ with open(self.log_path, "a") as f:
57
+ f.write(msg)
58
+ f.flush()
59
+ # Also print to stdout for HF Spaces logs
60
+ print(msg, end="", flush=True)
61
+
62
+
63
  # ---------------------------------------------------------------------------
64
  # Dataset
65
  # ---------------------------------------------------------------------------
 
70
  with open(data_path) as f:
71
  self.data = json.load(f)
72
  self.max_tokens = max_tokens
 
73
 
74
  def __len__(self):
75
  return len(self.data)
 
87
 
88
 
89
  # ---------------------------------------------------------------------------
90
+ # Training (file-based logging)
91
  # ---------------------------------------------------------------------------
92
+ def run_training_to_file(log_path: str = "/tmp/training_log.txt"):
93
+ """Run the full training pipeline, logging to a file."""
94
+ logger = _Logger(log_path)
 
 
 
 
 
95
 
96
  try:
97
+ _run_training(logger)
 
 
 
 
 
 
 
 
 
98
  except Exception as e:
99
+ logger.log(f"\n❌ FATAL ERROR: {e}\n")
100
+ logger.log(traceback.format_exc() + "\n")
 
101
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
+ def _run_training(logger: _Logger):
104
+ """Core training logic."""
105
+
106
+ # ── 1. Load tokenizer ──────────────────────────────────────────────────
107
+ logger.log("πŸ“¦ Loading OLMo 2 1B Instruct tokenizer...\n")
108
+
109
+ from transformers import AutoModelForCausalLM, AutoTokenizer
110
+
111
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
112
+ if tokenizer.pad_token is None:
113
+ tokenizer.pad_token = tokenizer.eos_token
114
+ logger.log(f"βœ… Tokenizer loaded. Vocab size: {len(tokenizer)}\n")
115
+
116
+ # ── 2. Load model ───────────────────────────────────────────────���──────
117
+ logger.log("πŸ“¦ Loading model in float32 on CPU (this takes ~2-3 min)...\n")
118
+ model = AutoModelForCausalLM.from_pretrained(
119
+ MODEL_NAME,
120
+ trust_remote_code=True,
121
+ torch_dtype=torch.float32,
122
+ )
123
+ n_params = sum(p.numel() for p in model.parameters()) / 1e6
124
+ logger.log(f"βœ… Model loaded. Parameters: {n_params:.1f}M\n")
125
+
126
+ # ── 3. Expand vocabulary ───────────────────────────────────────────────
127
+ logger.log(f"πŸ”€ Adding {CODEBOOK_SIZE} visual tokens + special tokens...\n")
128
  visual_tokens = [VIDEO_START, VIDEO_END, VIDEO_PAD]
129
  for i in range(CODEBOOK_SIZE):
130
  visual_tokens.append(f"<v_{i}>")
131
 
132
  num_added = tokenizer.add_tokens(visual_tokens)
133
  model.resize_token_embeddings(len(tokenizer))
134
+ logger.log(f"βœ… Added {num_added} tokens. New vocab size: {len(tokenizer)}\n")
135
+
136
+ # ── 4. Apply LoRA ─────────────────────────────────────────────────────
137
+ logger.log(f"πŸ”§ Applying LoRA (r={LORA_R}, alpha={LORA_ALPHA})...\n")
138
+ from peft import LoraConfig, get_peft_model, TaskType
139
+
140
+ lora_config = LoraConfig(
141
+ r=LORA_R,
142
+ lora_alpha=LORA_ALPHA,
143
+ target_modules=["q_proj", "v_proj"],
144
+ lora_dropout=LORA_DROPOUT,
145
+ bias="none",
146
+ task_type=TaskType.CAUSAL_LM,
147
+ )
148
+ model = get_peft_model(model, lora_config)
149
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
150
+ total = sum(p.numel() for p in model.parameters())
151
+ logger.log(f"βœ… LoRA applied. Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)\n")
152
+
153
+ # ── 5. Load dataset ───────────────────────────────────────────────────
154
+ data_path = "tokenized_dataset.json"
155
+ logger.log(f"πŸ“Š Loading dataset from {data_path}...\n")
156
+ dataset = VideoTokenDataset(data_path, max_tokens=256)
157
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
158
+ total_steps = NUM_EPOCHS * len(dataloader)
159
+ logger.log(f"πŸ“Š {len(dataset)} samples Γ— {NUM_EPOCHS} epochs = {total_steps} steps\n")
160
+
161
+ # ── 6. Train ──────────────────────────────────────────────────────────
162
+ logger.log("πŸ”₯ Starting training loop...\n\n")
 
 
 
 
 
 
 
 
 
163
 
164
  optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
165
  model.train()
 
168
  running_loss = 0.0
169
  start_time = time.time()
170
 
171
+ for epoch in range(NUM_EPOCHS):
172
+ epoch_loss = 0.0
173
+ num_batches = 0
174
+
175
+ for batch_idx, batch in enumerate(dataloader):
176
+ prompt = batch["prompt"][0]
177
+ video_tokens = batch["video_tokens"][0]
178
+
179
+ # Format training text (limit to 64 visual tokens for memory)
180
+ token_str = " ".join(f"<v_{t.item()}>" for t in video_tokens[:64])
181
+ text = f"Create a video of: {prompt} {VIDEO_START} {token_str} {VIDEO_END}"
182
+
183
+ inputs = tokenizer(
184
+ text,
185
+ return_tensors="pt",
186
+ truncation=True,
187
+ max_length=MAX_SEQ_LEN,
188
+ padding="max_length",
189
+ )
190
+
191
+ # Forward
192
+ outputs = model(**inputs, labels=inputs["input_ids"])
193
+ loss = outputs.loss / GRADIENT_ACCUMULATION
194
+
195
+ # Backward
196
+ loss.backward()
197
+
198
+ if (batch_idx + 1) % GRADIENT_ACCUMULATION == 0 or (batch_idx + 1) == len(dataloader):
199
+ torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
200
+ optimizer.step()
201
+ optimizer.zero_grad()
202
+
203
+ global_step += 1
204
+ batch_loss = loss.item() * GRADIENT_ACCUMULATION
205
+ epoch_loss += batch_loss
206
+ running_loss += batch_loss
207
+ num_batches += 1
208
+
209
+ elapsed = time.time() - start_time
210
+ steps_per_sec = global_step / elapsed if elapsed > 0 else 0
211
+
212
+ if batch_idx % LOG_EVERY == 0:
213
+ logger.log(
214
+ f" Epoch {epoch + 1}/{NUM_EPOCHS} | "
215
+ f"Step {batch_idx + 1}/{len(dataloader)} | "
216
+ f"Loss: {batch_loss:.4f} | "
217
+ f"Avg: {epoch_loss / num_batches:.4f} | "
218
+ f"Speed: {steps_per_sec:.2f} steps/s\n"
219
  )
220
 
221
+ del outputs, loss
222
+ gc.collect()
 
 
 
 
 
 
 
 
 
223
 
224
+ avg_epoch_loss = epoch_loss / num_batches
225
+ logger.log(f"\nπŸ“ˆ Epoch {epoch + 1} complete. Avg Loss: {avg_epoch_loss:.4f}\n\n")
 
 
 
226
 
227
+ total_time = time.time() - start_time
228
+ logger.log(f"βœ… Training complete in {total_time:.0f}s ({total_time / 60:.1f} min)\n")
229
+ logger.log(f" Final avg loss: {running_loss / global_step:.4f}\n\n")
230
 
231
+ # ── 7. Merge & push ──────────────────────────────────────────────────
232
+ logger.log("πŸ”€ Merging LoRA weights back into base model...\n")
233
+ model = model.merge_and_unload()
234
+ logger.log("βœ… LoRA merged.\n")
 
 
 
 
 
235
 
236
+ logger.log("πŸ’Ύ Saving model locally...\n")
237
+ save_dir = "./trained_model"
238
+ model.save_pretrained(save_dir, safe_serialization=True)
239
+ tokenizer.save_pretrained(save_dir)
240
+ logger.log("βœ… Model saved locally.\n")
241
 
242
+ logger.log(f"πŸš€ Pushing to {REPO_ID}...\n")
243
+ from huggingface_hub import HfApi
244
 
245
+ api = HfApi(token=HF_TOKEN)
 
 
 
 
 
 
 
246
 
 
 
247
  try:
248
+ api.create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
 
249
  except Exception as e:
250
+ logger.log(f"⚠️ Repo note: {e}\n")
251
 
252
+ api.upload_folder(
253
+ folder_path=save_dir,
254
+ repo_id=REPO_ID,
255
+ repo_type="model",
256
+ commit_message=f"LoRA-trained OLMo 2 1B (r={LORA_R}, {NUM_EPOCHS} epochs)",
257
+ )
258
+ logger.log(f"βœ… Model pushed to https://huggingface.co/{REPO_ID}\n")
259
+ logger.log("\nπŸŽ‰ All done! The trained model is now available on HuggingFace.\n")
 
 
 
 
 
 
260
 
 
261
 
262
+ # ---------------------------------------------------------------------------
263
+ # Generator version (for Gradio streaming if needed)
264
+ # ---------------------------------------------------------------------------
265
+ def train(data_path: str = "tokenized_dataset.json"):
266
+ """Generator version that yields log messages."""
267
+ import tempfile
268
+ log_path = tempfile.mktemp(suffix=".txt")
269
+ logger = _Logger(log_path)
270
+
271
+ # Start training in a thread
272
+ t = threading.Thread(target=lambda: _run_training(logger), daemon=True)
273
+ t.start()
274
+
275
+ # Stream log file
276
+ last_pos = 0
277
+ while t.is_alive():
278
+ time.sleep(1)
279
  try:
280
+ with open(log_path, "r") as f:
281
+ f.seek(last_pos)
282
+ new_content = f.read()
283
+ last_pos = f.tell()
284
+ if new_content:
285
+ yield new_content
286
+ except:
287
+ pass
288
+
289
+ # Final read
290
+ time.sleep(1)
291
+ try:
292
+ with open(log_path, "r") as f:
293
+ f.seek(last_pos)
294
+ final = f.read()
295
+ if final:
296
+ yield final
297
+ except:
298
+ pass
299
 
300
 
301
  # ---------------------------------------------------------------------------
302
+ # CLI entry point
303
  # ---------------------------------------------------------------------------
304
  if __name__ == "__main__":
305
  data_path = sys.argv[1] if len(sys.argv) > 1 else "tokenized_dataset.json"
306
+ run_training_to_file("/tmp/training_log.txt")