S-Dreamer commited on
Commit
33a206b
·
verified ·
1 Parent(s): f0292cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -18
app.py CHANGED
@@ -1,25 +1,104 @@
 
 
 
 
1
  import gradio as gr
2
- from trainer import run_finetune
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- def start_training(base_model, dataset_file, epochs):
6
- output = run_finetune(
7
- base_model=base_model,
8
- dataset_path=dataset_file.name,
9
- epochs=int(epochs),
10
  )
11
- return output
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- ui = gr.Interface(
15
- fn=start_training,
16
- inputs=[
17
- gr.Textbox(value="distilbert-base-uncased", label="Base model"),
18
- gr.File(label="Dataset (jsonl)"),
19
- gr.Number(value=3, label="Epochs")
20
- ],
21
- outputs=gr.Textbox(),
22
- title="HuggingFace Fine-Tuning Space"
23
- )
24
 
25
- ui.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import os
2
+ import time
3
+ from pathlib import Path
4
+
5
  import gradio as gr
 
6
 
7
+ from src.train import finetune_lora
8
+ from src.infer import load_generator, generate_text
9
+
10
+ def _default_output_root() -> Path:
11
+ # On Spaces, /data exists if Persistent Storage is enabled.
12
+ # Otherwise fall back to repo-local outputs/.
13
+ return Path("/data/outputs") if Path("/data").exists() else Path("outputs")
14
+
15
+ def run_train(
16
+ base_model: str,
17
+ dataset_id: str,
18
+ text_column: str,
19
+ max_train_samples: int,
20
+ max_steps: int,
21
+ lr: float,
22
+ batch_size: int,
23
+ lora_r: int,
24
+ lora_alpha: int,
25
+ lora_dropout: float,
26
+ ):
27
+ out_root = _default_output_root()
28
+ run_id = time.strftime("%Y%m%d-%H%M%S")
29
+ out_dir = out_root / run_id
30
+ out_dir.mkdir(parents=True, exist_ok=True)
31
+
32
+ status = finetune_lora(
33
+ base_model=base_model.strip(),
34
+ dataset_id=dataset_id.strip(),
35
+ text_column=text_column.strip(),
36
+ output_dir=str(out_dir),
37
+ max_train_samples=max_train_samples,
38
+ max_steps=max_steps,
39
+ learning_rate=lr,
40
+ batch_size=batch_size,
41
+ lora_r=lora_r,
42
+ lora_alpha=lora_alpha,
43
+ lora_dropout=lora_dropout,
44
+ )
45
 
46
+ adapter_path = out_dir / "adapter"
47
+ return (
48
+ f"Done.\n\nSaved to: {out_dir}\n\n{status}",
49
+ str(adapter_path) if adapter_path.exists() else None,
50
+ str(out_dir),
51
  )
 
52
 
53
+ def run_generate(base_model: str, adapter_dir: str, prompt: str, max_new_tokens: int, temperature: float):
54
+ gen = load_generator(base_model.strip(), adapter_dir.strip())
55
+ return generate_text(gen, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
56
+
57
+ with gr.Blocks(title="Fine-tune Pipeline (Docker)") as demo:
58
+ gr.Markdown("# Fine-tuning pipeline (LoRA) — Docker Space")
59
+
60
+ with gr.Tab("Train"):
61
+ base_model = gr.Textbox(value="sshleifer/tiny-gpt2", label="Base model (HF Hub id)")
62
+ dataset_id = gr.Textbox(value="karpathy/tiny_shakespeare", label="Dataset (HF Hub id)")
63
+ text_column = gr.Textbox(value="text", label="Text column")
64
+
65
+ with gr.Row():
66
+ max_train_samples = gr.Number(value=2000, precision=0, label="Max train samples")
67
+ max_steps = gr.Number(value=100, precision=0, label="Max steps")
68
+
69
+ with gr.Row():
70
+ lr = gr.Number(value=2e-4, label="Learning rate")
71
+ batch_size = gr.Number(value=2, precision=0, label="Batch size")
72
+
73
+ with gr.Row():
74
+ lora_r = gr.Number(value=8, precision=0, label="LoRA r")
75
+ lora_alpha = gr.Number(value=16, precision=0, label="LoRA alpha")
76
+ lora_dropout = gr.Number(value=0.05, label="LoRA dropout")
77
+
78
+ train_btn = gr.Button("Start fine-tune")
79
+ train_out = gr.Textbox(lines=10, label="Status")
80
+ adapter_file = gr.File(label="Adapter folder (download)")
81
+ out_dir_box = gr.Textbox(label="Output directory")
82
+
83
+ train_btn.click(
84
+ fn=run_train,
85
+ inputs=[base_model, dataset_id, text_column, max_train_samples, max_steps, lr, batch_size, lora_r, lora_alpha, lora_dropout],
86
+ outputs=[train_out, adapter_file, out_dir_box],
87
+ queue=True,
88
+ )
89
+
90
+ with gr.Tab("Generate"):
91
+ base_model2 = gr.Textbox(value="sshleifer/tiny-gpt2", label="Base model (must match training)")
92
+ adapter_dir = gr.Textbox(placeholder="Paste the output adapter dir path (e.g., outputs/2026.../adapter)", label="Adapter directory")
93
+ prompt = gr.Textbox(value="To be, or not to be,", lines=3, label="Prompt")
94
+
95
+ with gr.Row():
96
+ max_new_tokens = gr.Slider(16, 256, value=80, step=1, label="Max new tokens")
97
+ temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.05, label="Temperature")
98
+
99
+ gen_btn = gr.Button("Generate")
100
+ gen_out = gr.Textbox(lines=10, label="Output")
101
 
102
+ gen_btn.click(fn=run_generate, inputs=[base_model2, adapter_dir, prompt, max_new_tokens, temperature], outputs=[gen_out])
 
 
 
 
 
 
 
 
 
103
 
104
+ demo.launch()