Spaces:
Running
Running
Update app_gpu.py
Browse files- app_gpu.py +141 -162
app_gpu.py
CHANGED
|
@@ -252,182 +252,161 @@ def generate_long_prompt_cpu(base_model, lora_repo, short_prompt, max_length=200
|
|
| 252 |
import gradio as gr
|
| 253 |
|
| 254 |
def run_ui():
|
| 255 |
-
with gr.Blocks() as demo:
|
| 256 |
-
gr.Markdown("#
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
with gr.
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
epochs=int(ep_), lr=float(lr_), r=int(r_), alpha=int(a_),
|
| 283 |
-
batch_size=int(batch), num_workers=int(num_w),
|
| 284 |
-
max_train_records=int(max_rec), hf_repo_id=repo_
|
| 285 |
)
|
| 286 |
-
for item in gen:
|
| 287 |
-
yield item
|
| 288 |
-
|
| 289 |
-
btn = gr.Button("π Start Training")
|
| 290 |
-
btn.click(
|
| 291 |
-
fn=launch_train,
|
| 292 |
-
inputs=[
|
| 293 |
-
base_model, dataset, csvname, short_col, long_col,
|
| 294 |
-
batch_size, num_workers, r, a, ep, lr, max_records, repo
|
| 295 |
-
],
|
| 296 |
-
outputs=[logs],
|
| 297 |
-
queue=True
|
| 298 |
-
)
|
| 299 |
-
|
| 300 |
-
# ---------------- Inference (CPU) Tab ----------------
|
| 301 |
-
with gr.Tab("Inference (CPU)"):
|
| 302 |
-
inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
|
| 303 |
-
inf_lora_repo = gr.Textbox(label="LoRA HF repo", value="rahul7star/gemma-3-270m-ccebc0")
|
| 304 |
-
short_prompt = gr.Textbox(label="Short prompt")
|
| 305 |
-
long_prompt_out = gr.Textbox(label="Generated long prompt", lines=5)
|
| 306 |
-
|
| 307 |
-
inf_btn = gr.Button("π Generate Long Prompt")
|
| 308 |
-
inf_btn.click(
|
| 309 |
-
fn=generate_long_prompt_cpu,
|
| 310 |
-
inputs=[inf_base_model, inf_lora_repo, short_prompt],
|
| 311 |
-
outputs=[long_prompt_out]
|
| 312 |
-
)
|
| 313 |
-
|
| 314 |
-
# ---------------- Code Explain Tab ----------------
|
| 315 |
-
with gr.Tab("Code Explain"):
|
| 316 |
-
explain_md = gr.Markdown("""
|
| 317 |
-
### π§© Universal Dynamic LoRA Trainer & Inference β Code Explanation
|
| 318 |
-
|
| 319 |
-
This project provides an **end-to-end LoRA fine-tuning and inference system** for language models like **Gemma**, built with **Gradio**, **PEFT**, and **Accelerate**.
|
| 320 |
-
It supports both **training new LoRAs** and **generating text** with existing ones β all in a single interface.
|
| 321 |
-
|
| 322 |
-
---
|
| 323 |
-
|
| 324 |
-
#### **1οΈβ£ Imports Overview**
|
| 325 |
-
- **Core libs:** `os`, `torch`, `gradio`, `numpy`, `pandas`
|
| 326 |
-
- **Training libs:** `peft` (`LoraConfig`, `get_peft_model`), `accelerate` (`Accelerator`)
|
| 327 |
-
- **Modeling:** `transformers` (for Gemma base model)
|
| 328 |
-
- **Hub integration:** `huggingface_hub` (for uploading adapters)
|
| 329 |
-
- **Spaces:** `spaces` β for execution within Hugging Face Spaces
|
| 330 |
-
|
| 331 |
-
---
|
| 332 |
-
|
| 333 |
-
#### **2οΈβ£ Dataset Loading**
|
| 334 |
-
- Uses a lightweight **MediaTextDataset** class to load:
|
| 335 |
-
- CSV / Parquet files
|
| 336 |
-
- or directly from a Hugging Face dataset repo
|
| 337 |
-
- Expects two columns:
|
| 338 |
-
`short_prompt` β Input text
|
| 339 |
-
`long_prompt` β Target expanded text
|
| 340 |
-
- Supports batching, missing-column checks, and configurable max record limits.
|
| 341 |
-
|
| 342 |
-
---
|
| 343 |
-
|
| 344 |
-
#### **3οΈβ£ Model Loading & Preparation**
|
| 345 |
-
- Loads **Gemma model and tokenizer** via `AutoModelForCausalLM` and `AutoTokenizer`.
|
| 346 |
-
- Automatically detects **target modules** (e.g. `q_proj`, `v_proj`) for LoRA injection.
|
| 347 |
-
- Supports `float16` or `bfloat16` precision with `Accelerator` for optimal memory usage.
|
| 348 |
-
|
| 349 |
-
---
|
| 350 |
|
| 351 |
-
#
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
- Example:
|
| 375 |
-
\[
|
| 376 |
-
Q_{new} = Q + \alpha \times (B @ A)
|
| 377 |
-
\]
|
| 378 |
-
- Significantly reduces training cost:
|
| 379 |
-
- Memory: ~1β2% of full model
|
| 380 |
-
- Compute: trains faster with minimal GPU load
|
| 381 |
-
- Scalable to large models like Gemma 3B / 4B with rank β€ 16.
|
| 382 |
|
| 383 |
---
|
| 384 |
|
| 385 |
-
####
|
| 386 |
-
- **Train LoRA Tab:**
|
| 387 |
-
Configure model, dataset, LoRA parameters, and upload target.
|
| 388 |
-
Press **π Start Training** to stream training logs live.
|
| 389 |
|
| 390 |
-
|
| 391 |
-
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
---
|
| 397 |
|
| 398 |
-
|
| 399 |
|
| 400 |
```python
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
# -> Loads dataset or CSV file
|
| 407 |
-
# [DATA] 980 samples loaded, columns: short_prompt, long_prompt
|
| 408 |
-
|
| 409 |
-
print("[INFO] Initializing LoRA configuration...")
|
| 410 |
-
# -> Creates LoraConfig(r=8, alpha=16, target_modules=['q_proj', 'v_proj'])
|
| 411 |
-
# [CONFIG] LoRA applied to 96 attention layers
|
| 412 |
-
|
| 413 |
-
print("[INFO] Starting training loop...")
|
| 414 |
-
# [TRAIN] Step 1 | Loss: 2.31
|
| 415 |
-
# [TRAIN] Step 50 | Loss: 1.42
|
| 416 |
-
# [TRAIN] Step 100 | Loss: 0.91
|
| 417 |
-
# [TRAIN] Epoch 1 complete (avg loss: 1.21)
|
| 418 |
-
|
| 419 |
-
print("[INFO] Saving LoRA adapter...")
|
| 420 |
-
# -> Saves safetensors and config locally
|
| 421 |
-
|
| 422 |
-
print(f"[UPLOAD] Pushing adapter to {hf_repo_id}")
|
| 423 |
-
# -> Uploads model to Hugging Face Hub
|
| 424 |
-
# [UPLOAD] adapter_model.safetensors (67.7 MB)
|
| 425 |
-
# [SUCCESS] LoRA uploaded successfully π
|
| 426 |
-
""")
|
| 427 |
|
| 428 |
return demo
|
| 429 |
|
| 430 |
-
|
| 431 |
|
| 432 |
|
| 433 |
|
|
|
|
| 252 |
import gradio as gr
|
| 253 |
|
| 254 |
def run_ui():
|
| 255 |
+
with gr.Blocks(title="Prompt Enhancer Trainer + Inference UI") as demo:
|
| 256 |
+
gr.Markdown("# β¨ Prompt Enhancer Trainer + Inference Playground")
|
| 257 |
+
gr.Markdown("Train, test, and debug your LoRA-enhanced Gemma model easily.")
|
| 258 |
+
|
| 259 |
+
with gr.Tabs():
|
| 260 |
+
# -------------------------------
|
| 261 |
+
# 1οΈβ£ TRAINING TAB
|
| 262 |
+
# -------------------------------
|
| 263 |
+
with gr.Tab("Train Model"):
|
| 264 |
+
with gr.Row():
|
| 265 |
+
base_model = gr.Textbox(label="Base Model", value="google/gemma-2b-it")
|
| 266 |
+
dataset_path = gr.Textbox(label="Dataset Folder (Path)")
|
| 267 |
+
repo_id = gr.Textbox(label="Upload HF Repo (optional)", placeholder="username/my-enhancer-model")
|
| 268 |
+
|
| 269 |
+
with gr.Row():
|
| 270 |
+
output_dir = gr.Textbox(label="Local Output Directory", value="/tmp/prompt-enhancer")
|
| 271 |
+
train_btn = gr.Button("π Start Training")
|
| 272 |
+
|
| 273 |
+
train_log = gr.Textbox(label="Training Log", lines=20)
|
| 274 |
+
|
| 275 |
+
def train_model_ui(base_model, dataset_path, repo_id, output_dir):
|
| 276 |
+
return train_model(base_model, dataset_path, repo_id, output_dir)
|
| 277 |
+
|
| 278 |
+
train_btn.click(
|
| 279 |
+
train_model_ui,
|
| 280 |
+
inputs=[base_model, dataset_path, repo_id, output_dir],
|
| 281 |
+
outputs=[train_log],
|
|
|
|
|
|
|
|
|
|
| 282 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
+
# -------------------------------
|
| 285 |
+
# 2οΈβ£ INFERENCE TAB (CPU)
|
| 286 |
+
# -------------------------------
|
| 287 |
+
with gr.Tab("Inference (CPU Mode)"):
|
| 288 |
+
with gr.Row():
|
| 289 |
+
model_repo = gr.Textbox(label="HF Model Repo", value="gokaygokay/prompt-enhancer-gemma-3-270m-it")
|
| 290 |
+
user_prompt = gr.Textbox(label="Enter a short prompt", placeholder="a cat sitting on a chair")
|
| 291 |
+
gen_btn = gr.Button("π§ Generate Enhanced Prompt")
|
| 292 |
+
|
| 293 |
+
result_box = gr.Textbox(label="Enhanced Prompt", lines=10)
|
| 294 |
+
|
| 295 |
+
def run_inference(model_repo, user_prompt):
|
| 296 |
+
import torch
|
| 297 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 298 |
+
|
| 299 |
+
device = "cpu"
|
| 300 |
+
model = AutoModelForCausalLM.from_pretrained(model_repo, torch_dtype=torch.float32, device_map={"": device})
|
| 301 |
+
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
| 302 |
+
|
| 303 |
+
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map={"": device})
|
| 304 |
+
|
| 305 |
+
messages = [
|
| 306 |
+
{"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
|
| 307 |
+
{"role": "user", "content": user_prompt},
|
| 308 |
+
]
|
| 309 |
+
|
| 310 |
+
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 311 |
+
output = pipe(prompt, max_new_tokens=256)
|
| 312 |
+
return output[0]['generated_text']
|
| 313 |
+
|
| 314 |
+
gen_btn.click(run_inference, inputs=[model_repo, user_prompt], outputs=[result_box])
|
| 315 |
+
|
| 316 |
+
# -------------------------------
|
| 317 |
+
# 3οΈβ£ SHOW TRAINABLE PARAMS TAB
|
| 318 |
+
# -------------------------------
|
| 319 |
+
with gr.Tab("Show Trainable Params"):
|
| 320 |
+
gr.Markdown("### π§© View Trainable Parameters in Your LoRA-Enhanced Model")
|
| 321 |
+
|
| 322 |
+
with gr.Row():
|
| 323 |
+
base_model_name = gr.Textbox(label="Base Model", value="google/gemma-2b-it")
|
| 324 |
+
check_btn = gr.Button("π Show Trainable Layers")
|
| 325 |
+
|
| 326 |
+
param_output = gr.Textbox(label="Trainable Parameters Info", lines=25)
|
| 327 |
+
|
| 328 |
+
def show_trainable_layers(base_model_name):
|
| 329 |
+
import torch
|
| 330 |
+
from peft import get_peft_model, LoraConfig
|
| 331 |
+
from transformers import AutoModelForCausalLM
|
| 332 |
+
|
| 333 |
+
model = AutoModelForCausalLM.from_pretrained(base_model_name)
|
| 334 |
+
config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
|
| 335 |
+
model = get_peft_model(model, config)
|
| 336 |
+
model.print_trainable_parameters()
|
| 337 |
+
return (
|
| 338 |
+
"Each 'Adapter (90)' means 90 LoRA layers (pairs of A/B matrices) were injected.\n\n"
|
| 339 |
+
"π§ These typically correspond to:\n"
|
| 340 |
+
"- q_proj, k_proj, v_proj β Query, Key, Value projections\n"
|
| 341 |
+
"- o_proj or out_proj β Output of attention\n"
|
| 342 |
+
"- gate_proj, up_proj, down_proj β Feed-forward layers\n\n"
|
| 343 |
+
"π‘ So, 'Adapter (90)' = 90 target submodules were wrapped with LoRA.\n\n"
|
| 344 |
+
"Would you like to print them all? Here's how:\n\n"
|
| 345 |
+
"```python\n"
|
| 346 |
+
"for name, module in model.named_modules():\n"
|
| 347 |
+
" if 'lora' in name.lower():\n"
|
| 348 |
+
" print(name)\n"
|
| 349 |
+
"```\n"
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
check_btn.click(show_trainable_layers, inputs=[base_model_name], outputs=[param_output])
|
| 353 |
+
|
| 354 |
+
# -------------------------------
|
| 355 |
+
# 4οΈβ£ CODE DEBUG TAB
|
| 356 |
+
# -------------------------------
|
| 357 |
+
with gr.Tab("Code Debug"):
|
| 358 |
+
gr.Markdown("### π§© Code Debug β Understand What's Happening Line by Line")
|
| 359 |
+
debug_md = gr.Markdown(
|
| 360 |
+
"""
|
| 361 |
+
#### π§° Step-by-Step Breakdown
|
| 362 |
+
Below shows what each major step does internally during training:
|
| 363 |
+
|
| 364 |
+
1. **`f"[INFO] Loading base model: {base_model}"`**
|
| 365 |
+
β Logs which model is being loaded (e.g., `google/gemma-2b-it`)
|
| 366 |
+
|
| 367 |
+
2. **`AutoModelForCausalLM.from_pretrained(base_model)`**
|
| 368 |
+
β Downloads the base Gemma model weights and tokenizer.
|
| 369 |
+
|
| 370 |
+
3. **`get_peft_model(model, config)`**
|
| 371 |
+
β Wraps the model with LoRA. Injects adapters into `q_proj`, `k_proj`, `v_proj`, etc.
|
| 372 |
+
|
| 373 |
+
4. **Expected console output:**
|
| 374 |
+
[INFO] Loading base model: google/gemma-2b-it
|
| 375 |
+
[INFO] Preparing dataset...
|
| 376 |
+
[INFO] Injecting LoRA adapters...
|
| 377 |
+
trainable params: 3.5M || all params: 270M || trainable%: 1.3%
|
| 378 |
+
|
| 379 |
+
5. **`trainer.train()`**
|
| 380 |
+
β Starts training loop, showing tqdm progress bars per epoch.
|
| 381 |
|
| 382 |
+
6. **`upload_file(...)`**
|
| 383 |
+
β Uploads all model files to your chosen HF repo (if specified).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
---
|
| 386 |
|
| 387 |
+
#### π What βAdapter (90)β Means
|
|
|
|
|
|
|
|
|
|
| 388 |
|
| 389 |
+
When you initialize LoRA on Gemma, it finds **90 target layers** that match
|
| 390 |
+
typical names like:
|
| 391 |
+
- `q_proj`, `k_proj`, `v_proj`
|
| 392 |
+
- `o_proj`
|
| 393 |
+
- `gate_proj`, `up_proj`, `down_proj`
|
| 394 |
|
| 395 |
+
Each layer gets small trainable matrices **(A, B)** injected.
|
| 396 |
+
Hence you see:
|
| 397 |
+
> **Adapter (90)** β *90 modules modified by LoRA.*
|
|
|
|
| 398 |
|
| 399 |
+
You can list them in your own model like this:
|
| 400 |
|
| 401 |
```python
|
| 402 |
+
for name, module in model.named_modules():
|
| 403 |
+
if "lora" in name.lower():
|
| 404 |
+
print(name)
|
| 405 |
+
"""
|
| 406 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
|
| 408 |
return demo
|
| 409 |
|
|
|
|
| 410 |
|
| 411 |
|
| 412 |
|