Update app.py

app.py CHANGED
@@ -222,57 +222,87 @@ def log_message(output_log, msg):
 # 🧠 Train model to expand short prompts into long ones
 # =====================================================
+import torch
+from datasets import load_dataset
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    Trainer,
+    TrainingArguments,
+)
+from peft import LoraConfig, get_peft_model, TaskType
+import os
+
+# =====================================================
+# 🔧 Utility logging
+# =====================================================
+def log_message(output_log, msg):
+    print(msg)
+    output_log.append(msg)
+
+def start_async_upload(output_dir, hf_repo, output_log):
+    from huggingface_hub import upload_folder
+    try:
+        upload_folder(
+            repo_id=hf_repo,
+            folder_path=output_dir,
+            repo_type="model",
+            commit_message="Upload fine-tuned model"
+        )
+        log_message(output_log, f"☁️ Model uploaded to {hf_repo}")
+    except Exception as e:
+        log_message(output_log, f"⚠️ Upload failed: {e}")
+
+# =====================================================
+# 🧠 GPU-safe training for short→long prompt expansion
+# =====================================================
 @spaces.GPU(duration=300)
 def train_model(
     base_model: str,
     dataset_name: str,
-    num_epochs: int,
-    batch_size: int,
-    learning_rate: float,
-    hf_repo: str,
+    num_epochs: int = 1,
+    batch_size: int = 1,
+    learning_rate: float = 2e-4,
+    hf_repo: str = None,
 ):
     output_log = []

     try:
         log_message(output_log, "🚀 Initializing prompt expansion training...")

-        # ===== Device =====
+        # ===== Device setup =====
         device = "cuda" if torch.cuda.is_available() else "cpu"
-
+        dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+        log_message(output_log, f"🎮 Device: {device}, dtype: {dtype}")
+
         if device == "cuda":
-
+            gpu_name = torch.cuda.get_device_name(0)
+            gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+            log_message(output_log, f"✅ GPU: {gpu_name} ({gpu_mem:.1f} GB)")

-        # ===== Load dataset =====
+        # ===== Load dataset safely =====
         log_message(output_log, f"\n📚 Loading dataset: {dataset_name}")
         dataset = load_dataset(dataset_name)
-        dataset = dataset["train"].train_test_split(test_size=0.2)
+        dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)
         train_dataset = dataset["train"]
         test_dataset = dataset["test"]
+        log_message(output_log, f" → Train samples: {len(train_dataset)} | Test samples: {len(test_dataset)}")

-
-        log_message(output_log, f" → Test samples: {len(test_dataset)}")
-        log_message(output_log, f" → Columns: {train_dataset.column_names}")
-
-        # =====================================================
-        # 🧩 Format training examples
-        # Each sample has 'short' (input) and 'long' (target)
-        # =====================================================
+        # ===== Format examples =====
         def format_example(example):
-            short_prompt = example.get("
-            long_response = example.get("
+            short_prompt = example.get("short", "").strip()
+            long_response = example.get("long", "").strip()

-            # Compose a structured conversation
             prompt = (
-                f"<|system|>\nYou are
+                f"<|system|>\nYou are an AI that expands short prompts into detailed, descriptive versions.\n"
                 f"<|user|>\nShort: {short_prompt}\n"
                 f"<|assistant|>\n{long_response}"
             )
             return {"text": prompt}

-        train_dataset = train_dataset.map(format_example)
-        test_dataset = test_dataset.map(format_example)
-        log_message(output_log, f"✅ Prepared {len(train_dataset)} train + {len(test_dataset)} test samples")
+        train_dataset = train_dataset.map(format_example, num_proc=1)
+        test_dataset = test_dataset.map(format_example, num_proc=1)

-        # ===== Load model & tokenizer =====
+        # ===== Load model & tokenizer safely =====
         log_message(output_log, f"\n🤖 Loading base model: {base_model}")
         tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
         if tokenizer.pad_token is None:
@@ -280,55 +310,67 @@ def train_model(

         model = AutoModelForCausalLM.from_pretrained(
             base_model,
+            torch_dtype=dtype,
             trust_remote_code=True,
-            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
             low_cpu_mem_usage=True,
-
+            device_map="auto" if device == "cuda" else None,
+        )

-
+        # Enable memory optimizations
+        model.gradient_checkpointing_enable()
+        log_message(output_log, "✅ Model loaded with gradient checkpointing")

-        # =====
-        log_message(output_log, "\n⚙️ Applying LoRA
+        # ===== Apply lightweight LoRA =====
+        log_message(output_log, "\n⚙️ Applying LoRA fine-tuning config...")
         lora_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
-            r=
-            lora_alpha=
+            r=4,
+            lora_alpha=8,
             lora_dropout=0.1,
             target_modules=["q_proj", "v_proj"],
             bias="none",
         )
         model = get_peft_model(model, lora_config)
-
-        log_message(output_log, f"Trainable parameters: {trainable_params:,}")
+        log_message(output_log, f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

-        # ===== Tokenization =====
+        # ===== Tokenization (batched for speed) =====
         def tokenize_fn(examples):
             tokenized = tokenizer(
                 examples["text"],
                 padding="max_length",
                 truncation=True,
-                max_length=
+                max_length=384,
             )
             tokenized["labels"] = tokenized["input_ids"].copy()
             return tokenized

         train_dataset = train_dataset.map(tokenize_fn, batched=True)
         test_dataset = test_dataset.map(tokenize_fn, batched=True)
-        log_message(output_log, "✅ Tokenization
+        log_message(output_log, "✅ Tokenization done")

         # ===== Training setup =====
         output_dir = "./prompt_expander_lora"
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Automatically reduce batch size for low GPU VRAM
+        if device == "cuda" and gpu_mem < 10:
+            batch_size = 1
+            log_message(output_log, f"⚠️ GPU memory low → Using batch_size={batch_size}")
+
         training_args = TrainingArguments(
             output_dir=output_dir,
             num_train_epochs=num_epochs,
             per_device_train_batch_size=batch_size,
-            gradient_accumulation_steps=
-            warmup_steps=
-            logging_steps=
+            gradient_accumulation_steps=4,
+            warmup_steps=20,
+            logging_steps=10,
             save_strategy="epoch",
-            fp16=device == "cuda",
             optim="adamw_torch",
             learning_rate=learning_rate,
+            fp16=(dtype == torch.float16),
+            bf16=(dtype == torch.bfloat16),
+            max_grad_norm=1.0,
+            report_to="none",
         )

         trainer = Trainer(
@@ -340,22 +382,25 @@ def train_model(
         )

         # ===== Train =====
-        log_message(output_log, "\n🔥 Starting
+        log_message(output_log, "\n🔥 Starting safe LoRA fine-tuning...")
         trainer.train()

-        # ===== Save
-        log_message(output_log, "\n💾 Saving fine-tuned model
+        # ===== Save =====
+        log_message(output_log, "\n💾 Saving fine-tuned model...")
         trainer.save_model(output_dir)
         tokenizer.save_pretrained(output_dir)

         # ===== Upload to Hub =====
-
-
+        if hf_repo:
+            log_message(output_log, f"\n☁️ Uploading to {hf_repo} ...")
+            start_async_upload(output_dir, hf_repo, output_log)

-        log_message(output_log, "\n✅ Training complete
+        log_message(output_log, "\n✅ Training complete!")

+    except torch.cuda.OutOfMemoryError:
+        log_message(output_log, "\n❌ CUDA OOM — try lowering batch size or sequence length.")
     except Exception as e:
-        log_message(output_log, f"\n❌
+        log_message(output_log, f"\n❌ Training error: {e}")

     return "\n".join(output_log)
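For reference when testing this commit: train_model expects dataset records with "short" and "long" columns, saves the LoRA adapter and tokenizer to ./prompt_expander_lora, and uploads them to hf_repo when one is given. Below is a minimal inference sketch for the saved adapter, reusing the prompt template that format_example builds during training. The base-model name, example prompt, and generation settings are illustrative placeholders, not part of this commit.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Placeholder: use the same base_model that was passed to train_model.
base_model = "Qwen/Qwen2.5-0.5B-Instruct"  # assumption, not from this commit
adapter_dir = "./prompt_expander_lora"     # output_dir used by train_model

device = "cuda" if torch.cuda.is_available() else "cpu"

# train_model saved the tokenizer into the adapter directory as well.
tokenizer = AutoTokenizer.from_pretrained(adapter_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(model, adapter_dir)  # attach the trained LoRA weights
model.to(device)
model.eval()

# Same template format_example builds during training, stopped at the assistant turn.
prompt = (
    "<|system|>\nYou are an AI that expands short prompts into detailed, descriptive versions.\n"
    "<|user|>\nShort: a castle at dusk\n"
    "<|assistant|>\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.8)
# Decode only the newly generated tokens, i.e. the expanded prompt.
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))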