rahul7star committed on
Commit c05048d · verified · 1 Parent(s): d42c09b

Update app_gpu.py

Files changed (1)
  1. app_gpu.py +64 -1
app_gpu.py CHANGED
@@ -300,7 +300,70 @@ def run_ui():
  inputs=[inf_base_model, inf_lora_repo, short_prompt],
  outputs=[long_prompt_out])
 
-    return demo
+
+        # ---------------- Code Explain Tab ----------------
+        with gr.Tab("Code Explain"):
+            explain_md = gr.Markdown("""
+            ### Universal LoRA Trainer & Inference - Code Explanation
+
+            #### 1. Imports
+            - **spaces, os, torch, gradio, pandas, numpy**: General utilities, tensor ops, UI, and data handling.
+            - **peft (LoraConfig, get_peft_model)**: Creates LoRA adapters and attaches them to the base model.
+            - **accelerate (Accelerator)**: Simplifies device placement, mixed precision, and distributed training.
+            - **huggingface_hub**: Uploads LoRA weights to the HF Hub.
+            - **transformers (optional)**: Used when the base model is a Hugging Face LLM (Gemma).
+
+            #### 2. Dataset
+            - **MediaTextDataset**: Loads a CSV/Parquet file or an HF dataset and extracts `short_prompt` and `long_prompt`.
+            - Handles batched access and provides a fallback for missing columns.
+
+            #### 3. Model Loading
+            - `load_pipeline_auto`: Loads the Gemma tokenizer and model in float16 or float32, depending on the device.
+            - `find_target_modules`: Detects which Linear layers LoRA should be applied to (Q/K/V projections).
+
+            #### 4. LoRA Training (`train_lora_stream`)
+            `W_eff = W + (alpha / r) * B @ A`
+            - **LoRA Config**:
+              - `r` is the low-rank dimension.
+              - `alpha` scales the LoRA update (the effective scale is `alpha / r`).
+              - Targets Q/K/V or other Linear layers in attention.
+            - **Training**:
+              - The dataset is wrapped in a DataLoader.
+              - The LoRA model and optimizer are prepared with Accelerator.
+              - The forward pass computes the cross-entropy loss.
+              - Only the LoRA parameters are updated during backprop; the base weights stay frozen.
+              - Logs are streamed for each step.
+            - **Upload**: Saves the LoRA adapter and pushes it to the HF Hub.
+
+            #### 5. CPU Inference (`generate_long_prompt_cpu`)
+            - Loads the base Gemma model on the CPU (float32).
+            - Loads the LoRA weights with `PeftModel.from_pretrained`.
+            - Optionally merges the LoRA weights into the base model to simplify runtime.
+            - Tokenizes the short prompt and generates the expanded prompt using `generate()` with top-p/top-k sampling.
+
+            #### 6. LoRA Internals
+            - LoRA injects trainable matrices `A` and `B` into selected Linear layers (usually the Q/K/V projections in attention).
+            - `Query, Key, Value (Q/K/V)` are used in attention:
+            ```
+            Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
+            ```
+            - LoRA adds `(alpha / r) * B @ A` to the Q/K/V projection weights while the main model stays frozen.
+            - Efficient: only the small low-rank matrices are trained (`r << hidden_size`), reducing memory and compute.
+            - Other modules LoRA can target: `out_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`.
+
+            #### 7. Gradio UI
+            - **Train Tab**: User inputs for the model, dataset, LoRA params, and HF repo.
+            - **Inference Tab**: Short prompt → expanded long prompt using LoRA on CPU.
+            - **Code Explain Tab**: Static Markdown explaining the code logic and LoRA internals.
+            """)
+            # Components created inside a Blocks context render automatically,
+            # so no explicit explain_md.render() call is needed here.
+
+    return demo
+
+
+
+
+
 
  if __name__ == "__main__":
      run_ui().launch(server_name="0.0.0.0", server_port=7860, share=True)
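
Note on the LoRA formula in the Code Explain tab: the low-rank update described in sections 4 and 6 can be illustrated with a small sketch. This is not code from `app_gpu.py`; the helper names (`matmul`, `lora_effective_weight`, `lora_param_count`) and all sizes and values are hypothetical, and it assumes the standard `alpha / r` scaling that PEFT applies by default.

```python
# Minimal sketch of the LoRA update for one Linear layer, using plain lists.
# All names, sizes, and values are illustrative assumptions, not app_gpu.py code.

def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]

def lora_effective_weight(W, A, B, alpha, r):
    """Return W + (alpha / r) * (B @ A) without modifying W."""
    scale = alpha / r
    delta = matmul(B, A)  # (d_out x r) @ (r x d_in) -> d_out x d_in
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]

def lora_param_count(d_out, d_in, r):
    """Trainable parameters in one adapter: A is r x d_in, B is d_out x r."""
    return r * d_in + d_out * r

d_out, d_in, r, alpha = 3, 4, 2, 4
W = [[1.0] * d_in for _ in range(d_out)]        # frozen base weight
A = [[0.5] * d_in for _ in range(r)]            # trainable, r x d_in
B_zero = [[0.0] * r for _ in range(d_out)]      # B starts at zero, so the update starts at zero
B_trained = [[1.0] * r for _ in range(d_out)]   # stand-in for values after training

W_start = lora_effective_weight(W, A, B_zero, alpha, r)
W_after = lora_effective_weight(W, A, B_trained, alpha, r)
```

With `B` at zero the effective weight equals the frozen `W`, so training starts from the base model's behaviour. For a 4096 x 4096 projection with `r = 8`, `lora_param_count` gives 65,536 trainable parameters against 16,777,216 in the full matrix, which is where the memory saving comes from.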