rahul7star committed on
Commit 9985acb · verified · 1 Parent(s): 2dca78b

Update app_gpu.py

Files changed (1)
  1. app_gpu.py +141 -162
app_gpu.py CHANGED
@@ -252,182 +252,161 @@ def generate_long_prompt_cpu(base_model, lora_repo, short_prompt, max_length=200
 import gradio as gr

 def run_ui():
-    with gr.Blocks() as demo:
-        gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer & Inference")
-
-        # ---------------- Train LoRA Tab ----------------
-        with gr.Tab("Train LoRA"):
-            with gr.Row():
-                base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
-                dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
-                csvname = gr.Textbox(label="CSV/Parquet file", value="train-00000-of-00001.csv")
-                short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
-                long_col = gr.Textbox(label="Long prompt column", value="long_prompt")
-                repo = gr.Textbox(label="HF repo to upload LoRA", value="rahul7star/gemma-3-270m-ccebc0")
-
-            with gr.Row():
-                batch_size = gr.Number(value=1, label="Batch size")
-                num_workers = gr.Number(value=0, label="DataLoader num_workers")
-                r = gr.Number(value=8, label="LoRA rank")
-                a = gr.Number(value=16, label="LoRA alpha")
-                ep = gr.Number(value=1, label="Epochs")
-                lr = gr.Number(value=1e-4, label="Learning rate")
-                max_records = gr.Number(value=1000, label="Max training records")
-
-            logs = gr.Textbox(label="Logs (streaming)", lines=25)
-
-            def launch_train(bm, ds, csv, sc, lc, batch, num_w, r_, a_, ep_, lr_, max_rec, repo_):
-                gen = train_lora_stream(
-                    bm, ds, csv, [sc, lc],
-                    epochs=int(ep_), lr=float(lr_), r=int(r_), alpha=int(a_),
-                    batch_size=int(batch), num_workers=int(num_w),
-                    max_train_records=int(max_rec), hf_repo_id=repo_
                 )
-                for item in gen:
-                    yield item
-
-            btn = gr.Button("🚀 Start Training")
-            btn.click(
-                fn=launch_train,
-                inputs=[
-                    base_model, dataset, csvname, short_col, long_col,
-                    batch_size, num_workers, r, a, ep, lr, max_records, repo
-                ],
-                outputs=[logs],
-                queue=True
-            )
-
-        # ---------------- Inference (CPU) Tab ----------------
-        with gr.Tab("Inference (CPU)"):
-            inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
-            inf_lora_repo = gr.Textbox(label="LoRA HF repo", value="rahul7star/gemma-3-270m-ccebc0")
-            short_prompt = gr.Textbox(label="Short prompt")
-            long_prompt_out = gr.Textbox(label="Generated long prompt", lines=5)
-
-            inf_btn = gr.Button("📝 Generate Long Prompt")
-            inf_btn.click(
-                fn=generate_long_prompt_cpu,
-                inputs=[inf_base_model, inf_lora_repo, short_prompt],
-                outputs=[long_prompt_out]
-            )
-
-        # ---------------- Code Explain Tab ----------------
-        with gr.Tab("Code Explain"):
-            explain_md = gr.Markdown("""
-            ### 🧩 Universal Dynamic LoRA Trainer & Inference — Code Explanation
-
-            This project provides an **end-to-end LoRA fine-tuning and inference system** for language models like **Gemma**, built with **Gradio**, **PEFT**, and **Accelerate**.
-            It supports both **training new LoRAs** and **generating text** with existing ones — all in a single interface.
-
-            ---
-
-            #### **1️⃣ Imports Overview**
-            - **Core libs:** `os`, `torch`, `gradio`, `numpy`, `pandas`
-            - **Training libs:** `peft` (`LoraConfig`, `get_peft_model`), `accelerate` (`Accelerator`)
-            - **Modeling:** `transformers` (for the Gemma base model)
-            - **Hub integration:** `huggingface_hub` (for uploading adapters)
-            - **Spaces:** `spaces` — for execution within Hugging Face Spaces
-
-            ---
-
-            #### **2️⃣ Dataset Loading**
-            - Uses a lightweight **MediaTextDataset** class to load:
-              - CSV / Parquet files
-              - or directly from a Hugging Face dataset repo
-            - Expects two columns:
-              - `short_prompt` → input text
-              - `long_prompt` → target expanded text
-            - Supports batching, missing-column checks, and a configurable max record limit.
-
-            ---
-
-            #### **3️⃣ Model Loading & Preparation**
-            - Loads the **Gemma model and tokenizer** via `AutoModelForCausalLM` and `AutoTokenizer`.
-            - Automatically detects **target modules** (e.g. `q_proj`, `v_proj`) for LoRA injection.
-            - Supports `float16` or `bfloat16` precision with `Accelerator` for optimal memory usage.
-
-            ---
 
-            #### **4️⃣ LoRA Training Logic**
-            - Core formula:
-              \[
-              W_{eff} = W + \alpha \times (B @ A)
-              \]
-            - Only the **A** and **B** matrices are trainable; the base model weights remain frozen.
-            - Configurable parameters:
-              `r` (rank), `alpha` (scaling), `epochs`, `lr`, `batch_size`
-            - Training logs stream live in the UI, showing step-by-step loss values.
-            - After training, the adapter is **saved locally** and **uploaded to the Hugging Face Hub**.
-
-            ---
-
-            #### **5️⃣ CPU Inference Mode**
-            - Runs entirely on **CPU**; no GPU required.
-            - Loads the base Gemma model + trained LoRA weights (`PeftModel.from_pretrained`).
-            - Optionally merges the LoRA into the base model.
-            - Expands the short prompt → long descriptive text using standard generation parameters (e.g., top-p / top-k sampling).
-
-            ---
 
-            #### **6️⃣ LoRA Internals Explained**
-            - LoRA injects low-rank matrices (A, B) into **attention Linear layers**.
-            - Example:
-              \[
-              Q_{new} = Q + \alpha \times (B @ A)
-              \]
-            - Significantly reduces training cost:
-              - Memory: ~1–2% of the full model
-              - Compute: trains faster with minimal GPU load
-            - Scalable to large models like Gemma 3B / 4B with rank ≤ 16.
 
             ---
 
-            #### **7️⃣ Gradio UI Structure**
-            - **Train LoRA Tab:**
-              Configure model, dataset, LoRA parameters, and upload target.
-              Press **🚀 Start Training** to stream training logs live.
 
-            - **Inference (CPU) Tab:**
-              Type a short prompt → generates an expanded long-form version via the trained LoRA.
 
-            - **Code Explain Tab:**
-              Detailed breakdown of the logic + simulated console output below.
-
-            ---
 
-            ### 🧾 Example Log Simulation
 
             ```python
-            print(f"[INFO] Loading base model: {base_model}")
-            # -> Loads Gemma base model (fp16) on CUDA
-            # [INFO] Base model google/gemma-3-4b-it loaded successfully
-
-            print(f"[INFO] Preparing dataset from: {dataset_path}")
-            # -> Loads dataset or CSV file
-            # [DATA] 980 samples loaded, columns: short_prompt, long_prompt
-
-            print("[INFO] Initializing LoRA configuration...")
-            # -> Creates LoraConfig(r=8, alpha=16, target_modules=['q_proj', 'v_proj'])
-            # [CONFIG] LoRA applied to 96 attention layers
-
-            print("[INFO] Starting training loop...")
-            # [TRAIN] Step 1 | Loss: 2.31
-            # [TRAIN] Step 50 | Loss: 1.42
-            # [TRAIN] Step 100 | Loss: 0.91
-            # [TRAIN] Epoch 1 complete (avg loss: 1.21)
-
-            print("[INFO] Saving LoRA adapter...")
-            # -> Saves safetensors and config locally
-
-            print(f"[UPLOAD] Pushing adapter to {hf_repo_id}")
-            # -> Uploads model to Hugging Face Hub
-            # [UPLOAD] adapter_model.safetensors (67.7 MB)
-            # [SUCCESS] LoRA uploaded successfully 🚀
-            """)
 
     return demo
 
-
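The removed explanation's core formula, `W_eff = W + alpha * (B @ A)`, is easy to verify in isolation. Below is a minimal PyTorch sketch of that update; the shapes, values, and variable names are illustrative only and do not come from app_gpu.py (note that `peft` itself scales the update by `alpha / r` rather than by raw `alpha`):

```python
import torch

# Toy illustration of the LoRA update described above: W_eff = W + alpha * (B @ A).
# Hypothetical shapes; in real training only A and B would receive gradients.
d_out, d_in, rank, alpha = 64, 64, 8, 16

W = torch.randn(d_out, d_in)        # frozen base weight
A = torch.randn(rank, d_in) * 0.01  # trainable low-rank "down" factor
B = torch.zeros(d_out, rank)        # trainable "up" factor, initialized to zero

W_eff = W + alpha * (B @ A)         # rank-8 additive update on the frozen weight

# B starts at zero, so the adapted layer initially behaves exactly like the base layer.
x = torch.randn(d_in)
assert torch.allclose(W_eff @ x, W @ x)

# The factors are tiny next to the full matrix, which is where the
# "~1-2% of the full model" memory figure above comes from.
print("trainable:", A.numel() + B.numel(), "full:", W.numel())
```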
431
 
432
 
433
 
 
 import gradio as gr

 def run_ui():
+    with gr.Blocks(title="Prompt Enhancer Trainer + Inference UI") as demo:
+        gr.Markdown("# ✨ Prompt Enhancer Trainer + Inference Playground")
+        gr.Markdown("Train, test, and debug your LoRA-enhanced Gemma model easily.")
+
+        with gr.Tabs():
+            # -------------------------------
+            # 1️⃣ TRAINING TAB
+            # -------------------------------
+            with gr.Tab("Train Model"):
+                with gr.Row():
+                    base_model = gr.Textbox(label="Base Model", value="google/gemma-2b-it")
+                    dataset_path = gr.Textbox(label="Dataset Folder (Path)")
+                    repo_id = gr.Textbox(label="Upload HF Repo (optional)", placeholder="username/my-enhancer-model")
+
+                with gr.Row():
+                    output_dir = gr.Textbox(label="Local Output Directory", value="/tmp/prompt-enhancer")
+                    train_btn = gr.Button("🚀 Start Training")
+
+                train_log = gr.Textbox(label="Training Log", lines=20)
+
+                def train_model_ui(base_model, dataset_path, repo_id, output_dir):
+                    return train_model(base_model, dataset_path, repo_id, output_dir)
+
+                train_btn.click(
+                    train_model_ui,
+                    inputs=[base_model, dataset_path, repo_id, output_dir],
+                    outputs=[train_log],
                 )
 
+            # -------------------------------
+            # 2️⃣ INFERENCE TAB (CPU)
+            # -------------------------------
+            with gr.Tab("Inference (CPU Mode)"):
+                with gr.Row():
+                    model_repo = gr.Textbox(label="HF Model Repo", value="gokaygokay/prompt-enhancer-gemma-3-270m-it")
+                    user_prompt = gr.Textbox(label="Enter a short prompt", placeholder="a cat sitting on a chair")
+                    gen_btn = gr.Button("🧠 Generate Enhanced Prompt")
+
+                result_box = gr.Textbox(label="Enhanced Prompt", lines=10)
+
+                def run_inference(model_repo, user_prompt):
+                    import torch
+                    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
+                    device = "cpu"
+                    model = AutoModelForCausalLM.from_pretrained(model_repo, torch_dtype=torch.float32, device_map={"": device})
+                    tokenizer = AutoTokenizer.from_pretrained(model_repo)
+
+                    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map={"": device})
+
+                    messages = [
+                        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
+                        {"role": "user", "content": user_prompt},
+                    ]
+
+                    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+                    output = pipe(prompt, max_new_tokens=256)
+                    return output[0]['generated_text']
+
+                gen_btn.click(run_inference, inputs=[model_repo, user_prompt], outputs=[result_box])
+
+            # -------------------------------
+            # 3️⃣ SHOW TRAINABLE PARAMS TAB
+            # -------------------------------
+            with gr.Tab("Show Trainable Params"):
+                gr.Markdown("### 🧩 View Trainable Parameters in Your LoRA-Enhanced Model")
+
+                with gr.Row():
+                    base_model_name = gr.Textbox(label="Base Model", value="google/gemma-2b-it")
+                    check_btn = gr.Button("🔍 Show Trainable Layers")
+
+                param_output = gr.Textbox(label="Trainable Parameters Info", lines=25)
+
+                def show_trainable_layers(base_model_name):
+                    import torch
+                    from peft import get_peft_model, LoraConfig
+                    from transformers import AutoModelForCausalLM
+
+                    model = AutoModelForCausalLM.from_pretrained(base_model_name)
+                    config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
+                    model = get_peft_model(model, config)
+                    model.print_trainable_parameters()
+                    return (
+                        "Each 'Adapter (90)' means 90 LoRA layers (pairs of A/B matrices) were injected.\n\n"
+                        "🧠 These typically correspond to:\n"
+                        "- q_proj, k_proj, v_proj → Query, Key, Value projections\n"
+                        "- o_proj or out_proj → Output of attention\n"
+                        "- gate_proj, up_proj, down_proj → Feed-forward layers\n\n"
+                        "💡 So, 'Adapter (90)' = 90 target submodules were wrapped with LoRA.\n\n"
+                        "Would you like to print them all? Here's how:\n\n"
+                        "```python\n"
+                        "for name, module in model.named_modules():\n"
+                        "    if 'lora' in name.lower():\n"
+                        "        print(name)\n"
+                        "```\n"
+                    )
+
+                check_btn.click(show_trainable_layers, inputs=[base_model_name], outputs=[param_output])
+
+            # -------------------------------
+            # 4️⃣ CODE DEBUG TAB
+            # -------------------------------
+            with gr.Tab("Code Debug"):
+                gr.Markdown("### 🧩 Code Debug — Understand What's Happening Line by Line")
+                debug_md = gr.Markdown(
+                    """
+                    #### 🧰 Step-by-Step Breakdown
+                    Below is what each major step does internally during training:
+
+                    1. **`f"[INFO] Loading base model: {base_model}"`**
+                       → Logs which model is being loaded (e.g., `google/gemma-2b-it`)
+
+                    2. **`AutoModelForCausalLM.from_pretrained(base_model)`**
+                       → Downloads the base Gemma model weights and tokenizer.
+
+                    3. **`get_peft_model(model, config)`**
+                       → Wraps the model with LoRA. Injects adapters into `q_proj`, `k_proj`, `v_proj`, etc.
+
+                    4. **Expected console output:**
+                       [INFO] Loading base model: google/gemma-2b-it
+                       [INFO] Preparing dataset...
+                       [INFO] Injecting LoRA adapters...
+                       trainable params: 3.5M || all params: 270M || trainable%: 1.3%
+
+                    5. **`trainer.train()`**
+                       → Starts the training loop, showing tqdm progress bars per epoch.
 
+                    6. **`upload_file(...)`**
+                       → Uploads all model files to your chosen HF repo (if specified).
 
                     ---
 
+                    #### 🔍 What “Adapter (90)” Means
 
+                    When you initialize LoRA on Gemma, it finds **90 target layers** that match
+                    typical names like:
+                    - `q_proj`, `k_proj`, `v_proj`
+                    - `o_proj`
+                    - `gate_proj`, `up_proj`, `down_proj`
 
+                    Each layer gets small trainable matrices **(A, B)** injected.
+                    Hence you see:
+                    > **Adapter (90)** → *90 modules modified by LoRA.*
 
+                    You can list them in your own model like this:
 
                     ```python
+                    for name, module in model.named_modules():
+                        if "lora" in name.lower():
+                            print(name)
+                    """
+                )
 
     return demo
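For reference, the CPU inference path that both versions of the UI describe (load the base model, attach a trained LoRA adapter, optionally merge it, then sample) can be sketched as follows. This is a minimal example assuming a peft-format adapter, not the app's actual `generate_long_prompt_cpu`; the repo IDs are simply the defaults shown in the old UI:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative defaults taken from the UI above; substitute your own repos.
base_id = "google/gemma-3-4b-it"
adapter_id = "rahul7star/gemma-3-270m-ccebc0"

tokenizer = AutoTokenizer.from_pretrained(base_id)
base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.float32)  # fp32 on CPU

model = PeftModel.from_pretrained(base, adapter_id)  # attach the LoRA weights
model = model.merge_and_unload()                     # optional: fold B @ A into the base weights
model.eval()

inputs = tokenizer("a cat sitting on a chair", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=128, do_sample=True, top_p=0.9, top_k=50)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

Merging is optional: keeping the adapter separate lets you swap LoRAs at runtime, while merging removes the extra matmul per layer at inference time.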