Spaces:
Running
Running
Update app_gpu.py
Browse files- app_gpu.py +82 -17
app_gpu.py
CHANGED
|
@@ -409,25 +409,90 @@ for name, module in model.named_modules():
|
|
| 409 |
# 5️⃣ CODE EXPLAIN TAB
|
| 410 |
# =========================================================
|
| 411 |
with gr.Tab("Code Explain"):
|
| 412 |
-
|
| 413 |
-
Universal Dynamic LoRA Trainer & Inference — Code Explanation
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
print(f"[INFO] Loading base model: {base_model}")
|
| 425 |
-
# -> Loads base model
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
print("[INFO] Saving LoRA adapter...")
|
| 430 |
-
# ->
|
|
|
|
|
|
|
|
|
|
| 431 |
# [SUCCESS] LoRA uploaded successfully 🚀
|
| 432 |
""")
|
| 433 |
return demo
|
|
|
|
| 409 |
# 5️⃣ CODE EXPLAIN TAB
|
| 410 |
# =========================================================
|
| 411 |
with gr.Tab("Code Explain"):
|
| 412 |
+
explain_md = gr.Markdown("""
|
| 413 |
+
### 🧩 Universal Dynamic LoRA Trainer & Inference — Code Explanation
|
| 414 |
+
This project provides an **end-to-end LoRA fine-tuning and inference system** for language models like **Gemma**, built with **Gradio**, **PEFT**, and **Accelerate**.
|
| 415 |
+
It supports both **training new LoRAs** and **generating text** with existing ones — all in a single interface.
|
| 416 |
+
---
|
| 417 |
+
#### **1️⃣ Imports Overview**
|
| 418 |
+
- **Core libs:** `os`, `torch`, `gradio`, `numpy`, `pandas`
|
| 419 |
+
- **Training libs:** `peft` (`LoraConfig`, `get_peft_model`), `accelerate` (`Accelerator`)
|
| 420 |
+
- **Modeling:** `transformers` (for Gemma base model)
|
| 421 |
+
- **Hub integration:** `huggingface_hub` (for uploading adapters)
|
| 422 |
+
- **Spaces:** `spaces` — for execution within Hugging Face Spaces
|
| 423 |
+
---
|
| 424 |
+
#### **2️⃣ Dataset Loading**
|
| 425 |
+
- Uses a lightweight **MediaTextDataset** class to load:
|
| 426 |
+
- CSV / Parquet files
|
| 427 |
+
- or directly from a Hugging Face dataset repo
|
| 428 |
+
- Expects two columns:
|
| 429 |
+
`short_prompt` → Input text
|
| 430 |
+
`long_prompt` → Target expanded text
|
| 431 |
+
- Supports batching, missing-column checks, and configurable max record limits.
|
| 432 |
+
---
|
| 433 |
+
#### **3️⃣ Model Loading & Preparation**
|
| 434 |
+
- Loads **Gemma model and tokenizer** via `AutoModelForCausalLM` and `AutoTokenizer`.
|
| 435 |
+
- Automatically detects **target modules** (e.g. `q_proj`, `v_proj`) for LoRA injection.
|
| 436 |
+
- Supports `float16` or `bfloat16` precision with `Accelerator` for optimal memory usage.
|
| 437 |
+
---
|
| 438 |
+
#### **4️⃣ LoRA Training Logic**
|
| 439 |
+
- Core formula:
|
| 440 |
+
\[
|
| 441 |
+
W_{eff} = W + \alpha \times (B @ A)
|
| 442 |
+
\]
|
| 443 |
+
- Only **A** and **B** matrices are trainable; base model weights remain frozen.
|
| 444 |
+
- Configurable parameters:
|
| 445 |
+
`r` (rank), `alpha` (scaling), `epochs`, `lr`, `batch_size`
|
| 446 |
+
- Training logs stream live in the UI, showing step-by-step loss values.
|
| 447 |
+
- After training, the adapter is **saved locally** and **uploaded to Hugging Face Hub**.
|
| 448 |
+
---
|
| 449 |
+
#### **5️⃣ CPU Inference Mode**
|
| 450 |
+
- Runs entirely on **CPU**, no GPU required.
|
| 451 |
+
- Loads base Gemma model + trained LoRA weights (`PeftModel.from_pretrained`).
|
| 452 |
+
- Optionally merges LoRA with base model.
|
| 453 |
+
- Expands the short prompt → long descriptive text using standard generation parameters (e.g., top-p / top-k sampling).
|
| 454 |
+
---
|
| 455 |
+
#### **6️⃣ LoRA Internals Explained**
|
| 456 |
+
- LoRA injects low-rank matrices (A, B) into **attention Linear layers**.
|
| 457 |
+
- Example:
|
| 458 |
+
\[
|
| 459 |
+
Q_{new} = Q + \alpha \times (B @ A)
|
| 460 |
+
\]
|
| 461 |
+
- Significantly reduces training cost:
|
| 462 |
+
- Memory: ~1–2% of full model
|
| 463 |
+
- Compute: trains faster with minimal GPU load
|
| 464 |
+
- Scalable to large models like Gemma 3B / 4B with rank ≤ 16.
|
| 465 |
+
---
|
| 466 |
+
#### **7️⃣ Gradio UI Structure**
|
| 467 |
+
- **Train LoRA Tab:**
|
| 468 |
+
Configure model, dataset, LoRA parameters, and upload target.
|
| 469 |
+
Press **🚀 Start Training** to stream training logs live.
|
| 470 |
+
- **Inference (CPU) Tab:**
|
| 471 |
+
Type a short prompt → Generates expanded long-form version via trained LoRA.
|
| 472 |
+
- **Code Explain Tab:**
|
| 473 |
+
Detailed breakdown of logic + simulated console output below.
|
| 474 |
+
---
|
| 475 |
+
### 🧾 Example Log Simulation
|
| 476 |
+
```python
|
| 477 |
print(f"[INFO] Loading base model: {base_model}")
|
| 478 |
+
# -> Loads Gemma base model (fp16) on CUDA
|
| 479 |
+
# [INFO] Base model google/gemma-3-4b-it loaded successfully
|
| 480 |
+
print(f"[INFO] Preparing dataset from: {dataset_path}")
|
| 481 |
+
# -> Loads dataset or CSV file
|
| 482 |
+
# [DATA] 980 samples loaded, columns: short_prompt, long_prompt
|
| 483 |
+
print("[INFO] Initializing LoRA configuration...")
|
| 484 |
+
# -> Creates LoraConfig(r=8, alpha=16, target_modules=['q_proj', 'v_proj'])
|
| 485 |
+
# [CONFIG] LoRA applied to 96 attention layers
|
| 486 |
+
print("[INFO] Starting training loop...")
|
| 487 |
+
# [TRAIN] Step 1 | Loss: 2.31
|
| 488 |
+
# [TRAIN] Step 50 | Loss: 1.42
|
| 489 |
+
# [TRAIN] Step 100 | Loss: 0.91
|
| 490 |
+
# [TRAIN] Epoch 1 complete (avg loss: 1.21)
|
| 491 |
print("[INFO] Saving LoRA adapter...")
|
| 492 |
+
# -> Saves safetensors and config locally
|
| 493 |
+
print(f"[UPLOAD] Pushing adapter to {hf_repo_id}")
|
| 494 |
+
# -> Uploads model to Hugging Face Hub
|
| 495 |
+
# [UPLOAD] adapter_model.safetensors (67.7 MB)
|
| 496 |
# [SUCCESS] LoRA uploaded successfully 🚀
|
| 497 |
""")
|
| 498 |
return demo
|