Update app.py

app.py CHANGED

@@ -225,69 +225,64 @@ def log_message(output_log, msg):
 def train_model(
     base_model: str,
     dataset_name: str,
+    hf_repo: str,
     num_epochs: int = 1,
-    batch_size: int =
-    learning_rate: float =
-    hf_repo: str = None,
+    batch_size: int = 2,
+    learning_rate: float = 5e-4,
 ):
     output_log = []
 
     try:
-        log_message(output_log, "🚀
+        log_message(output_log, "🚀 Starting FAST test training...")
 
-        # ===== Device
+        # ===== Device =====
         device = "cuda" if torch.cuda.is_available() else "cpu"
         dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
         log_message(output_log, f"🎮 Device: {device}, dtype: {dtype}")
-
         if device == "cuda":
-            gpu_name = torch.cuda.get_device_name(0)
-            gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
-            log_message(output_log, f"✅ GPU: {gpu_name} ({gpu_mem:.1f} GB)")
+            log_message(output_log, f"✅ GPU: {torch.cuda.get_device_name(0)}")
 
-        # ===== Load dataset
+        # ===== Load dataset =====
         log_message(output_log, f"\n📂 Loading dataset: {dataset_name}")
         dataset = load_dataset(dataset_name)
         dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)
-        train_dataset = dataset["train"]
-        test_dataset = dataset["test"]
-
+        train_dataset, test_dataset = dataset["train"], dataset["test"]
+
+        # ===== ⚡ FAST mode: use small subset =====
+        train_dataset = train_dataset.select(range(min(100, len(train_dataset))))
+        test_dataset = test_dataset.select(range(min(20, len(test_dataset))))
+        log_message(output_log, f"⚡ Using {len(train_dataset)} train / {len(test_dataset)} test samples")
 
-        # ===== Format
+        # ===== Format samples =====
         def format_example(example):
             short_prompt = example.get("short", "").strip()
             long_response = example.get("long", "").strip()
-
-            prompt = (
-
-                f"<|user|>\nShort: {short_prompt}\n"
-                f"<|assistant|>\n{long_response}"
-            )
-            return {"text": prompt}
+            return {
+                "text": (
+                    f"<|system|>\nYou are an AI that expands short prompts into detailed, descriptive ones.\n"
+                    f"<|user|>\nShort: {short_prompt}\n"
+                    f"<|assistant|>\n{long_response}"
+                )
+            }
 
-        train_dataset = train_dataset.map(format_example, num_proc=1)
-        test_dataset = test_dataset.map(format_example, num_proc=1)
+        train_dataset = train_dataset.map(format_example)
+        test_dataset = test_dataset.map(format_example)
 
-        # =====
-        log_message(output_log, f"\n🤖 Loading base model: {base_model}")
+        # ===== Tokenizer & Model =====
         tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
-            torch_dtype=dtype,
             trust_remote_code=True,
-
+            torch_dtype=dtype,
             device_map="auto" if device == "cuda" else None,
+            low_cpu_mem_usage=True,
         )
-
-        # Enable memory optimizations
         model.gradient_checkpointing_enable()
-        log_message(output_log, "✅ Model loaded with gradient checkpointing")
 
-        # =====
-        log_message(output_log, "\n⚙️ Applying LoRA fine-tuning config...")
+        # ===== LoRA setup =====
         lora_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
             r=4,

@@ -297,46 +292,39 @@ def train_model(
             bias="none",
         )
         model = get_peft_model(model, lora_config)
-        log_message(output_log,
+        log_message(output_log, "✅ LoRA applied successfully")
 
-        # ===== Tokenization
+        # ===== Tokenization =====
         def tokenize_fn(examples):
             tokenized = tokenizer(
                 examples["text"],
                 padding="max_length",
                 truncation=True,
-                max_length=
+                max_length=256,
             )
             tokenized["labels"] = tokenized["input_ids"].copy()
             return tokenized
 
         train_dataset = train_dataset.map(tokenize_fn, batched=True)
         test_dataset = test_dataset.map(tokenize_fn, batched=True)
-        log_message(output_log, "✅ Tokenization done")
 
         # ===== Training setup =====
-        output_dir = "./
+        output_dir = "./prompt_expander_fast"
         os.makedirs(output_dir, exist_ok=True)
 
-        # Automatically reduce batch size for low GPU VRAM
-        if device == "cuda" and gpu_mem < 10:
-            batch_size = 1
-            log_message(output_log, f"⚠️ GPU memory low → Using batch_size={batch_size}")
-
         training_args = TrainingArguments(
             output_dir=output_dir,
             num_train_epochs=num_epochs,
             per_device_train_batch_size=batch_size,
-            gradient_accumulation_steps=
-            warmup_steps=
-            logging_steps=
-            save_strategy="
-            optim="adamw_torch",
-            learning_rate=learning_rate,
+            gradient_accumulation_steps=2,
+            warmup_steps=5,
+            logging_steps=5,
+            save_strategy="no",  # don't save checkpoints
             fp16=(dtype == torch.float16),
             bf16=(dtype == torch.bfloat16),
-
+            learning_rate=learning_rate,
             report_to="none",
+            optim="adamw_torch",
         )
 
         trainer = Trainer(

@@ -348,25 +336,26 @@ def train_model(
         )
 
         # ===== Train =====
-        log_message(output_log, "\n🔥
+        log_message(output_log, "\n🔥 Quick training started...")
         trainer.train()
 
-        # ===== Save =====
-        log_message(output_log, "\n💾 Saving fine-tuned model...")
-
+        # ===== Save + Upload =====
+        log_message(output_log, "\n💾 Saving fast fine-tuned model...")
+        model.save_pretrained(output_dir)
         tokenizer.save_pretrained(output_dir)
 
-
-
-
-
+        log_message(output_log, f"☁️ Uploading model to {hf_repo} ...")
+        upload_folder(
+            repo_id=hf_repo,
+            folder_path=output_dir,
+            repo_type="model",
+            commit_message="Quick test fine-tune upload",
+        )
 
-        log_message(output_log, "\n✅
+        log_message(output_log, "\n✅ FAST training completed successfully!")
 
-    except torch.cuda.OutOfMemoryError:
-        log_message(output_log, "\n❌ CUDA OOM → try lowering batch size or sequence length.")
     except Exception as e:
-        log_message(output_log, f"
+        log_message(output_log, f"❌ Error: {e}")
 
     return "\n".join(output_log)

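For context: the hunks reference names whose imports sit outside this diff. Judging from the calls visible above (a reconstruction for orientation, not part of the change itself), the top of app.py needs roughly:

    import os

    import torch
    from datasets import load_dataset
    from huggingface_hub import upload_folder
    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
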
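Note the breaking signature change: hf_repo moves from an optional keyword (`hf_repo: str = None`) to a required positional parameter, so any caller that previously omitted it will now raise a TypeError. A minimal calling sketch, where the model id, dataset name, and repo id are placeholders rather than values from this Space:

    log = train_model(
        base_model="gpt2",                                # placeholder model id
        dataset_name="your-username/short-long-prompts",  # placeholder dataset
        hf_repo="your-username/prompt-expander-test",     # placeholder target repo
        num_epochs=1,
    )
    print(log)
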
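The new FAST mode trims both splits with datasets.Dataset.select, which keeps only the rows at the given indices; wrapping the cap in min() avoids indexing past the end of a dataset smaller than the cap. The pattern in isolation:

    from datasets import Dataset

    ds = Dataset.from_dict({"short": ["a cat"], "long": ["a fluffy tabby cat on a windowsill"]})

    # Keep at most 100 rows, without overrunning a smaller dataset.
    subset = ds.select(range(min(100, len(ds))))
    print(len(subset))  # 1 here; 100 for a large dataset
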
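The gap between the first two hunks elides the middle of the LoraConfig, so only task_type, r=4, and bias="none" are visible here. For orientation only, a config of this shape usually also sets lora_alpha, lora_dropout, and target_modules; the values below are hypothetical, not what app.py contains:

    from peft import LoraConfig, TaskType

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=4,                                  # visible in the diff
        lora_alpha=16,                        # hypothetical
        lora_dropout=0.05,                    # hypothetical
        target_modules=["q_proj", "v_proj"],  # hypothetical; depends on the base model
        bias="none",                          # visible in the diff
    )
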
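One caveat on the tokenization step: labels are a straight copy of input_ids while padding="max_length" pads every row, so the loss is also computed over pad tokens. The Trainer ignores label positions set to -100, so a common refinement (a sketch, not what this commit does) is to mask them:

    def tokenize_fn(examples):
        tokenized = tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=256,
        )
        # Copy input_ids, then mask padding so it doesn't contribute to the loss.
        labels = [ids.copy() for ids in tokenized["input_ids"]]
        for row, mask in zip(labels, tokenized["attention_mask"]):
            for i, m in enumerate(mask):
                if m == 0:
                    row[i] = -100
        tokenized["labels"] = labels
        return tokenized
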
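The Trainer construction itself falls in the gap between the last two hunks, so its arguments are not shown. Given the objects prepared above, the usual wiring would look roughly like this (a sketch of the conventional call, not the verbatim code):

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
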
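upload_folder assumes the target repo already exists and that a token with write access is available (for a Space, typically an HF_TOKEN secret; locally, huggingface-cli login). A slightly more defensive sketch creates the repo first:

    from huggingface_hub import create_repo, upload_folder

    create_repo(hf_repo, repo_type="model", exist_ok=True)  # no-op if it already exists
    upload_folder(
        repo_id=hf_repo,
        folder_path=output_dir,
        repo_type="model",
        commit_message="Quick test fine-tune upload",
    )
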
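Finally, the hand-rolled <|system|>/<|user|>/<|assistant|> markers only help if the base model was trained on that exact format. For a base model that ships a chat template, tokenizer.apply_chat_template produces the matching format instead; an alternative sketch (assuming the tokenizer from the function above, and not what this commit does):

    def format_example(example):
        messages = [
            {"role": "system", "content": "You are an AI that expands short prompts into detailed, descriptive ones."},
            {"role": "user", "content": f"Short: {example.get('short', '').strip()}"},
            {"role": "assistant", "content": example.get("long", "").strip()},
        ]
        return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}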