Spaces:
Running
Running
Update app_gpu.py
Browse files- app_gpu.py +44 -54
app_gpu.py
CHANGED
|
@@ -291,7 +291,7 @@ def run_ui():
|
|
| 291 |
outputs=[logs],
|
| 292 |
queue=True)
|
| 293 |
|
| 294 |
-
# ---------------- Inference Tab ----------------
|
| 295 |
with gr.Tab("Inference (CPU)"):
|
| 296 |
inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
|
| 297 |
inf_lora_repo = gr.Textbox(label="LoRA HF repo", value="rahul7star/gemma-3-270m-ccebc0")
|
|
@@ -305,63 +305,53 @@ def run_ui():
|
|
| 305 |
|
| 306 |
# ---------------- Code Explain Tab ----------------
|
| 307 |
with gr.Tab("Code Explain"):
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
- Inference: Expand short prompts using LoRA.
|
| 351 |
-
- Code Explain: This simulation showing internal workflow & parameter effects.
|
| 352 |
-
|
| 353 |
-
**Simulated Training Logs:**\n{logs}
|
| 354 |
-
"""
|
| 355 |
-
return explanation
|
| 356 |
-
|
| 357 |
-
explain_btn = gr.Button("📝 Show Code Explanation & Logs")
|
| 358 |
-
explain_btn.click(fn=explain_code,
|
| 359 |
-
inputs=[model_explain, lora_rank, lora_alpha, epochs],
|
| 360 |
-
outputs=[logs_out])
|
| 361 |
|
| 362 |
return demo
|
| 363 |
|
| 364 |
|
|
|
|
| 365 |
|
| 366 |
|
| 367 |
|
|
|
|
| 291 |
outputs=[logs],
|
| 292 |
queue=True)
|
| 293 |
|
| 294 |
+
# ---------------- Inference (CPU) Tab ----------------
|
| 295 |
with gr.Tab("Inference (CPU)"):
|
| 296 |
inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
|
| 297 |
inf_lora_repo = gr.Textbox(label="LoRA HF repo", value="rahul7star/gemma-3-270m-ccebc0")
|
|
|
|
| 305 |
|
| 306 |
# ---------------- Code Explain Tab ----------------
|
| 307 |
with gr.Tab("Code Explain"):
|
| 308 |
+
explain_md = gr.Markdown("""
|
| 309 |
+
### Universal LoRA Trainer & Inference - Code Explanation
|
| 310 |
+
|
| 311 |
+
#### 1. Imports
|
| 312 |
+
- `spaces, os, torch, gradio, pandas, numpy`: utilities, tensor ops, UI, and data handling.
|
| 313 |
+
- `peft (LoraConfig, get_peft_model)`: LoRA adapter integration.
|
| 314 |
+
- `accelerate (Accelerator)`: device placement, mixed precision, distributed training.
|
| 315 |
+
- `huggingface_hub`: upload LoRA weights to Hugging Face.
|
| 316 |
+
- `transformers (optional)`: only for Gemma LLM.
|
| 317 |
+
|
| 318 |
+
#### 2. Dataset
|
| 319 |
+
- `MediaTextDataset`: Loads CSV/Parquet or HF dataset, extracts `short_prompt` and `long_prompt`.
|
| 320 |
+
- Handles batched access and missing columns.
|
| 321 |
+
|
| 322 |
+
#### 3. Model Loading
|
| 323 |
+
- `load_pipeline_auto`: Loads Gemma tokenizer + model in float16/32.
|
| 324 |
+
- `find_target_modules`: Detects Linear layers (Q/K/V) for LoRA injection.
|
| 325 |
+
|
| 326 |
+
#### 4. LoRA Training
|
| 327 |
+
- LoRA formula: `W_eff = W + (alpha / r) * B @ A`
|
| 328 |
+
- `r` = low-rank dimension, `alpha` = scaling factor.
|
| 329 |
+
- Only trains LoRA matrices, main model frozen.
|
| 330 |
+
- Efficient memory & compute, streams logs.
|
| 331 |
+
- Supports uploading trained LoRA to Hugging Face Hub.
|
| 332 |
+
|
| 333 |
+
#### 5. CPU Inference
|
| 334 |
+
- Loads base Gemma model on CPU (float32).
|
| 335 |
+
- Loads LoRA with `PeftModel.from_pretrained`.
|
| 336 |
+
- Optionally merges LoRA into base model.
|
| 337 |
+
- Generates long prompt using `generate()` with top-p/top-k sampling.
|
| 338 |
+
|
| 339 |
+
#### 6. LoRA Internals
|
| 340 |
+
- LoRA injects trainable matrices `A` & `B` into selected Linear layers.
|
| 341 |
+
- Q/K/V projection weights in attention are updated as `W_q_new = W_q + (alpha / r) * B @ A` (likewise for K and V).
|
| 342 |
+
- Efficient: `r << hidden_size`, only small matrices trained.
|
| 343 |
+
|
| 344 |
+
#### 7. Gradio UI
|
| 345 |
+
- Train Tab: configure model, dataset, LoRA, HF repo.
|
| 346 |
+
- Inference Tab: short prompt → expanded long prompt.
|
| 347 |
+
- Code Explain Tab: shows this detailed explanation of the code.
|
| 348 |
+
""")
|
| 349 |
+
explain_md.render()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
return demo
|
| 352 |
|
| 353 |
|
| 354 |
+
|
| 355 |
|
| 356 |
|
| 357 |
|