aaravriyer193 committed on
Commit e4e6387 · verified · 1 Parent(s): 6e2167d

Create app.py

Files changed (1)
  1. app.py +279 -0
app.py ADDED
@@ -0,0 +1,279 @@
+ import threading
+
+ import gradio as gr
+ import torch
+ from datasets import load_dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     Trainer,
+     TrainerCallback,
+     DataCollatorForLanguageModeling,
+ )
+ from peft import LoraConfig, get_peft_model, TaskType
+
+ # ── Globals ──────────────────────────────────────────────────────────────────
+ training_log = []
+ training_thread = None
+ stop_flag = threading.Event()
+
+
+ def log(msg: str):
+     training_log.append(msg)
+     print(msg)
+
+
+ # ── Core training function ────────────────────────────────────────────────────
+ def run_finetuning(
+     model_name: str,
+     dataset_name: str,
+     dataset_config: str,
+     text_column: str,
+     num_train_epochs: int,
+     per_device_batch_size: int,
+     learning_rate: float,
+     max_seq_length: int,
+     use_lora: bool,
+     lora_r: int,
+     output_dir: str,
+ ):
+     global training_log
+     training_log = []
+     stop_flag.clear()
+
+     try:
+         log(f"🔧 Loading tokenizer: {model_name}")
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         log(f"📦 Loading model: {model_name}")
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.float32,  # full precision is the safe choice on CPU
+             low_cpu_mem_usage=True,
+         )
+
+         if use_lora:
+             log(f"⚡ Applying LoRA (r={lora_r}) ...")
+             lora_config = LoraConfig(
+                 task_type=TaskType.CAUSAL_LM,
+                 r=lora_r,
+                 lora_alpha=lora_r * 2,  # common heuristic: alpha = 2r
+                 lora_dropout=0.05,
+                 bias="none",
+                 # Covers GPT-2-style (c_attn, c_proj) and LLaMA/OPT-style
+                 # (q/k/v/o_proj) layer names; peft adapts whichever names the
+                 # model actually has and errors only if none match.
+                 target_modules=["c_attn", "c_proj", "q_proj", "v_proj", "k_proj", "o_proj"],
+             )
+             model = get_peft_model(model, lora_config)
+             trainable, total = model.get_nb_trainable_parameters()
+             log(f"   Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
+
+         log(f"📂 Loading dataset: {dataset_name}" + (f" ({dataset_config})" if dataset_config.strip() else ""))
+         ds_kwargs = {"split": "train", "trust_remote_code": True}
+         if dataset_config.strip():
+             dataset = load_dataset(dataset_name, dataset_config, **ds_kwargs)
+         else:
+             dataset = load_dataset(dataset_name, **ds_kwargs)
+
+         # Take a small sample to keep the demo CPU-friendly
+         dataset = dataset.select(range(min(500, len(dataset))))
+         log(f"   Using {len(dataset)} training samples")
+
+         def tokenize(batch):
+             texts = [str(t) for t in batch[text_column]]
+             return tokenizer(
+                 texts,
+                 truncation=True,
+                 max_length=max_seq_length,
+                 padding="max_length",
+             )
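+         # padding="max_length" pads every example to the full max_seq_length,
+         # which is simple but memory-hungry; dynamic per-batch padding
+         # (padding=True plus a padding collator) would likely be lighter on CPU.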
+
+         log("🔀 Tokenizing dataset ...")
+         tokenized = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
+         tokenized.set_format("torch")
+
+         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
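+         # mlm=False puts the collator in causal-LM mode: labels are a copy of
+         # input_ids with pad positions set to -100 so the loss skips them; the
+         # model applies the next-token shift internally.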
+
+         training_args = TrainingArguments(
+             output_dir=output_dir,
+             num_train_epochs=num_train_epochs,
+             per_device_train_batch_size=per_device_batch_size,
+             learning_rate=learning_rate,
+             logging_steps=5,
+             save_strategy="epoch",
+             fp16=False,
+             bf16=False,
+             no_cuda=True,  # force CPU even if a GPU is visible
+             report_to="none",
+             disable_tqdm=False,
+         )
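+         # With the UI defaults (500 samples, batch size 2, 1 epoch) this comes
+         # to 500 / 2 = 250 optimizer steps, so logging every 5 steps yields
+         # 50 loss readings per epoch.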
+
+         class StreamLogger(TrainerCallback):
+             def on_log(self, args, state, control, logs=None, **kwargs):
+                 if logs:
+                     step = state.global_step
+                     loss = logs.get("loss", "n/a")
+                     lr = logs.get("learning_rate", "n/a")
+                     log(f"   step {step:>4} | loss: {loss} | lr: {lr}")
+
+             def on_epoch_end(self, args, state, control, **kwargs):
+                 log(f"✅ Epoch {int(state.epoch)} complete")
+                 # Stop requests are only honored here, so the running epoch
+                 # always finishes before training actually stops.
+                 if stop_flag.is_set():
+                     control.should_training_stop = True
+
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=tokenized,
+             data_collator=data_collator,
+             callbacks=[StreamLogger()],
+         )
+
+         log("🚀 Starting training ...")
+         trainer.train()
+
+         log(f"💾 Saving model to: {output_dir}")
+         trainer.save_model(output_dir)
+         tokenizer.save_pretrained(output_dir)
+         log("🎉 Fine-tuning complete!")
+
+     except Exception as e:
+         log(f"❌ Error: {e}")
+
+
+ # ── Gradio helpers ────────────────────────────────────────────────────────────
+ def start_training(
+     model_name, dataset_name, dataset_config, text_column,
+     num_epochs, batch_size, learning_rate, max_seq_len,
+     use_lora, lora_r, output_dir,
+ ):
+     global training_thread
+     if training_thread and training_thread.is_alive():
+         return "⚠️ Training already running!"
+
+     training_thread = threading.Thread(
+         target=run_finetuning,
+         args=(
+             model_name, dataset_name, dataset_config, text_column,
+             num_epochs, batch_size, learning_rate, max_seq_len,
+             use_lora, lora_r, output_dir,
+         ),
+         daemon=True,
+     )
+     training_thread.start()
+     return "Training started! Check the log below."
+
+
+ def stop_training():
+     stop_flag.set()
+     return "🛑 Stop signal sent."
+
+
+ def get_logs():
+     return "\n".join(training_log) if training_log else "No logs yet..."
+
+
+ def is_running():
+     return "🟢 Running" if (training_thread and training_thread.is_alive()) else "⚫ Idle"
+
+
+ # ── Gradio UI ─────────────────────────────────────────────────────────────────
+ with gr.Blocks(
+     title="LLM Fine-Tuner",
+     theme=gr.themes.Base(
+         primary_hue="emerald",
+         neutral_hue="zinc",
+         font=gr.themes.GoogleFont("JetBrains Mono"),
+     ),
+     css="""
+     .container { max-width: 900px; margin: auto; }
+     .gr-button-primary { background: #10b981 !important; }
+     footer { display: none !important; }
+     """,
+ ) as demo:
+     gr.Markdown(
+         """
+         # 🤖 LLM Fine-Tuner
+         Fine-tune small language models on Hugging Face datasets. CPU-friendly, with LoRA support.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### 🧠 Model")
+             model_name = gr.Dropdown(
+                 choices=[
+                     "distilgpt2",
+                     "gpt2",
+                     "facebook/opt-125m",
+                     "EleutherAI/pythia-70m",
+                     "EleutherAI/pythia-160m",
+                     "microsoft/phi-1_5",
+                 ],
+                 value="distilgpt2",
+                 label="Base Model",
+                 allow_custom_value=True,
+             )
+
+             gr.Markdown("### 📦 Dataset")
+             dataset_name = gr.Textbox(value="wikitext", label="Dataset Name (HF Hub)")
+             dataset_config = gr.Textbox(value="wikitext-2-raw-v1", label="Dataset Config (optional)")
+             text_column = gr.Textbox(value="text", label="Text Column")
+
+         with gr.Column(scale=1):
+             gr.Markdown("### ⚙️ Training")
+             num_epochs = gr.Slider(1, 10, value=1, step=1, label="Epochs")
+             batch_size = gr.Slider(1, 16, value=2, step=1, label="Batch Size")
+             learning_rate = gr.Number(value=2e-4, label="Learning Rate")
+             max_seq_len = gr.Slider(32, 512, value=128, step=32, label="Max Sequence Length")
+             output_dir = gr.Textbox(value="./finetuned-model", label="Output Directory")
+
+             gr.Markdown("### ⚡ LoRA (recommended for CPU)")
+             use_lora = gr.Checkbox(value=True, label="Use LoRA")
+             lora_r = gr.Slider(4, 64, value=8, step=4, label="LoRA Rank (r)")
+
+     with gr.Row():
+         start_btn = gr.Button("🚀 Start Fine-Tuning", variant="primary")
+         stop_btn = gr.Button("🛑 Stop", variant="secondary")
+         status_btn = gr.Button("🔄 Refresh Status")
+
+     status_box = gr.Textbox(label="Status", value="⚫ Idle", interactive=False)
+     log_box = gr.Textbox(
+         label="Training Log",
+         lines=20,
+         max_lines=30,
+         interactive=False,
+         placeholder="Logs will appear here once training starts...",
+     )
+
+     start_btn.click(
+         fn=start_training,
+         inputs=[
+             model_name, dataset_name, dataset_config, text_column,
+             num_epochs, batch_size, learning_rate, max_seq_len,
+             use_lora, lora_r, output_dir,
+         ],
+         outputs=status_box,
+     )
+
+     stop_btn.click(fn=stop_training, outputs=status_box)
+     status_btn.click(fn=lambda: (is_running(), get_logs()), outputs=[status_box, log_box])
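+     # Status and logs are refreshed manually via the button; recent Gradio
+     # releases also offer gr.Timer for automatic polling, if the pinned
+     # version supports it.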
+
+     gr.Markdown(
+         """
+         ---
+         **Tips:**
+         - `distilgpt2` (82M parameters) is the best starting point on CPU.
+         - Enable **LoRA** to drastically reduce memory use and training time.
+         - Keep **Max Sequence Length ≤ 128** and **Batch Size = 1–2** on the free CPU tier.
+         - The dataset is capped at **500 samples** for CPU-friendly runs; edit the code to raise the cap.
+         """
+     )
+
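+ # Hedged usage sketch: after a LoRA run, the saved adapter can typically be
+ # loaded back for inference with peft's AutoPeftModelForCausalLM, e.g.:
+ #
+ #     from peft import AutoPeftModelForCausalLM
+ #     model = AutoPeftModelForCausalLM.from_pretrained("./finetuned-model")
+ #     tok = AutoTokenizer.from_pretrained("./finetuned-model")
+ #     ids = tok("Hello", return_tensors="pt")
+ #     print(tok.decode(model.generate(**ids, max_new_tokens=40)[0]))
+ #
+ # "./finetuned-model" matches the UI default output directory.
+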
+ if __name__ == "__main__":
+     demo.launch(share=False)