rahul7star committed on
Commit
b392d21
·
verified ·
1 Parent(s): b2315b7

Update app_gpu.py

Browse files
Files changed (1) hide show
  1. app_gpu.py +44 -54
app_gpu.py CHANGED
@@ -291,7 +291,7 @@ def run_ui():
291
  outputs=[logs],
292
  queue=True)
293
 
294
- # ---------------- Inference Tab ----------------
295
  with gr.Tab("Inference (CPU)"):
296
  inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
297
  inf_lora_repo = gr.Textbox(label="LoRA HF repo", value="rahul7star/gemma-3-270m-ccebc0")
@@ -305,63 +305,53 @@ def run_ui():
305
 
306
  # ---------------- Code Explain Tab ----------------
307
  with gr.Tab("Code Explain"):
308
-
309
- def simulate_logs(base_model, r_, a_, ep_):
310
- simulated = [
311
- f"[INFO] Loading base model: {base_model}",
312
- f"[INFO] LoRA configuration: r={r_}, alpha={a_}",
313
- f"[INFO] Epoch {i+1}/{ep_} started..." for i in range(int(ep_))
314
- ]
315
- for ep_idx in range(int(ep_)):
316
- for step in range(1, 6):
317
- simulated.append(f"[DEBUG] Step {step}, Loss: {0.01 * (6-step):.6f}")
318
- simulated.append(f"[INFO] Epoch {ep_idx+1} completed.")
319
- simulated.append("[INFO] LoRA training finished. Ready to upload to HF Hub.")
320
- return "\n".join(simulated)
321
-
322
- model_explain = gr.Textbox(label="Base Model", value="google/gemma-3-4b-it")
323
- lora_rank = gr.Number(label="LoRA rank (r)", value=8)
324
- lora_alpha = gr.Number(label="LoRA alpha", value=16)
325
- epochs = gr.Number(label="Epochs", value=1)
326
- logs_out = gr.Textbox(label="Simulated Logs & Explanation", lines=30)
327
-
328
- def explain_code(model_name, r_, a_, ep_):
329
- logs = simulate_logs(model_name, r_, a_, ep_)
330
- explanation = f"""
331
- ### Universal LoRA Trainer & Inference - Detailed Explanation
332
-
333
- 1. **Imports**: Handles data, tensor ops, LoRA PEFT, HF Hub integration, and optional Transformers.
334
- 2. **Dataset**: `MediaTextDataset` loads short & long prompts from CSV/Parquet/HF.
335
- 3. **Model Loader**: Loads base Gemma model; detects Linear layers (Q/K/V) to apply LoRA.
336
- 4. **LoRA Internals**:
337
- - LoRA injects low-rank matrices `A` and `B` into Q/K/V projections.
338
- - `Effective weight: W_eff = W + alpha * B @ A`
339
- - Only LoRA parameters are trained; main model frozen.
340
- 5. **Training Loop**:
341
- - Forward pass Cross-entropy loss.
342
- - Backprop updates LoRA weights only.
343
- - Accelerator handles device placement & mixed precision.
344
- 6. **CPU Inference**:
345
- - Loads base + LoRA on CPU.
346
- - Merges LoRA optionally to avoid runtime PEFT issues.
347
- - Generates expanded prompt from short prompt.
348
- 7. **Gradio UI Tabs**:
349
- - Train LoRA: Configure training, see live logs.
350
- - Inference: Expand short prompts using LoRA.
351
- - Code Explain: This simulation showing internal workflow & parameter effects.
352
-
353
- **Simulated Training Logs:**\n{logs}
354
- """
355
- return explanation
356
-
357
- explain_btn = gr.Button("📝 Show Code Explanation & Logs")
358
- explain_btn.click(fn=explain_code,
359
- inputs=[model_explain, lora_rank, lora_alpha, epochs],
360
- outputs=[logs_out])
361
 
362
  return demo
363
 
364
 
 
365
 
366
 
367
 
 
291
  outputs=[logs],
292
  queue=True)
293
 
294
+ # ---------------- Inference (CPU) Tab ----------------
295
  with gr.Tab("Inference (CPU)"):
296
  inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
297
  inf_lora_repo = gr.Textbox(label="LoRA HF repo", value="rahul7star/gemma-3-270m-ccebc0")
 
305
 
306
  # ---------------- Code Explain Tab ----------------
307
  with gr.Tab("Code Explain"):
308
+ explain_md = gr.Markdown("""
309
+ ### Universal LoRA Trainer & Inference - Code Explanation
310
+
311
+ #### 1. Imports
312
+ - `spaces, os, torch, gradio, pandas, numpy`: utilities, tensor ops, UI, and data handling.
313
+ - `peft (LoraConfig, get_peft_model)`: LoRA adapter integration.
314
+ - `accelerate (Accelerator)`: device placement, mixed precision, distributed training.
315
+ - `huggingface_hub`: upload LoRA weights to Hugging Face.
316
+ - `transformers (optional)`: only for Gemma LLM.
317
+
318
+ #### 2. Dataset
319
+ - `MediaTextDataset`: Loads CSV/Parquet or HF dataset, extracts `short_prompt` and `long_prompt`.
320
+ - Handles batched access and missing columns.
321
+
322
+ #### 3. Model Loading
323
+ - `load_pipeline_auto`: Loads Gemma tokenizer + model in float16/32.
324
+ - `find_target_modules`: Detects Linear layers (Q/K/V) for LoRA injection.
325
+
326
+ #### 4. LoRA Training
327
+ - LoRA formula: `W_eff = W + alpha * B @ A`
328
+ - `r` = low-rank dimension, `alpha` = scaling factor.
329
+ - Only trains LoRA matrices, main model frozen.
330
+ - Efficient memory & compute, streams logs.
331
+ - Supports uploading trained LoRA to Hugging Face Hub.
332
+
333
+ #### 5. CPU Inference
334
+ - Loads base Gemma model on CPU (float32).
335
+ - Loads LoRA with `PeftModel.from_pretrained`.
336
+ - Optionally merges LoRA into base model.
337
+ - Generates long prompt using `generate()` with top-p/top-k sampling.
338
+
339
+ #### 6. LoRA Internals
340
+ - LoRA injects trainable matrices `A` & `B` into selected Linear layers.
341
+ - Q/K/V matrices in attention updated as `Q_new = Q + alpha*B@A`.
342
+ - Efficient: `r << hidden_size`, only small matrices trained.
343
+
344
+ #### 7. Gradio UI
345
+ - Train Tab: configure model, dataset, LoRA, HF repo.
346
+ - Inference Tab: short prompt expanded long prompt.
347
+ - Code Explain Tab: shows detailed explanation and simulated logs.
348
+ """)
349
+ explain_md.render()
 
 
 
 
 
 
 
 
 
 
 
350
 
351
  return demo
352
 
353
 
354
+
355
 
356
 
357