# NOTE: removed HuggingFace Space file-viewer residue (status lines, blame
# hashes, and a line-number gutter) that was not part of the source file.
#!/usr/bin/env python3
"""
Gradio app for DPO training on HuggingFace Space.
Training methods:
- SFT (Supervised Fine-Tuning): Traditional baseline
- DPO (Direct Preference Optimization): Recommended (+26% F1)
"""
import gradio as gr
import spaces
import subprocess
import os
from datasets import load_dataset
# Pre-download RLHF dataset at startup
def download_rlhf_dataset():
    """Fetch the RLHF preference dataset from the Hub unless a local copy already exists."""
    dataset_path = "data/hf_rlhf_dataset_v1"
    # Guard clause: skip the network round-trip when the cache is present.
    if os.path.exists(dataset_path):
        print(f"✅ Dataset already cached: {dataset_path}")
        return
    print("📥 Downloading RLHF dataset from Hub...")
    load_dataset("OliverSlivka/itemset-extraction-rlhf-v1").save_to_disk(dataset_path)
    print(f"✅ Dataset cached to {dataset_path}")
# Download at startup (module import time) so the dataset is on disk before
# any training job is launched from the UI.
download_rlhf_dataset()
def run_training(training_method, training_mode, model_size):
    """
    Launch a fine-tuning subprocess and stream its log output.

    Generator: yields the accumulated log text after every line the child
    process prints, so the Gradio textbox live-updates with the full history.

    Args:
        training_method: "sft" or "dpo"; any other value falls through to DPO.
        training_mode: "test" for quick validation, "full" for production.
        model_size: "7B" selects Qwen2.5-7B-Instruct; anything else selects 3B.

    Yields:
        str: the banner, then the banner plus all child output so far, then a
        final success/failure footer based on the child's return code.

    NOTE(review): the original docstring said GPU support comes "via
    @spaces.GPU decorator", but no decorator is applied to this function —
    confirm whether the launched training scripts acquire the GPU themselves.
    """
    # The original code read HF_TOKEN with os.getenv and wrote the same value
    # back into os.environ — a no-op round-trip, removed. The child process
    # inherits this process's environment, so the Space secret HF_TOKEN is
    # already visible to it.

    # Model selection: the 7B run doubles gradient accumulation to keep the
    # effective batch size up while the per-device batch size stays at 1.
    if model_size == "7B":
        model_name = "Qwen/Qwen2.5-7B-Instruct"
        batch_size = 1
        grad_accum = 16  # Higher accumulation for larger model
    else:
        model_name = "Qwen/Qwen2.5-3B-Instruct"
        batch_size = 1
        grad_accum = 8

    if training_method == "sft":
        # SFT training (baseline)
        if training_mode == "test":
            command = f"python src/training/run_sft_test.py --model-name {model_name}"
            description = f"🧪 SFT TEST: 50 examples, {model_size} (4-bit LoRA)"
            expected_time = "15-25 minutes" if model_size == "7B" else "10-15 minutes"
        else:
            command = f"python src/training/run_sft_full.py --model-name {model_name}"
            description = f"🚀 SFT PRODUCTION: 439 examples, 3 epochs, {model_size}"
            expected_time = "90-120 minutes" if model_size == "7B" else "40-60 minutes"
    else:
        # DPO training (recommended). Test mode halves gradient accumulation
        # for quicker steps on the small 100-pair subset.
        if training_mode == "test":
            command = f"""python src/training/run_dpo_training.py \
    --model_name {model_name} \
    --dataset_path data/hf_rlhf_dataset_v1 \
    --output_dir ./dpo_test_checkpoints_{model_size.lower()} \
    --num_train_epochs 1 \
    --per_device_train_batch_size {batch_size} \
    --gradient_accumulation_steps {grad_accum // 2} \
    --learning_rate 5e-5 \
    --beta 0.1 \
    --use_4bit \
    --use_lora \
    --max_length 2048 \
    --max_prompt_length 1024 \
    --eval_steps 50 \
    --save_steps 100"""
            description = f"⭐ DPO TEST: 100 pairs, {model_size} (4-bit LoRA)"
            expected_time = "20-30 minutes" if model_size == "7B" else "15-20 minutes"
        else:
            command = f"""python src/training/run_dpo_training.py \
    --model_name {model_name} \
    --dataset_path data/hf_rlhf_dataset_v1 \
    --output_dir ./dpo_checkpoints_{model_size.lower()} \
    --num_train_epochs 3 \
    --per_device_train_batch_size {batch_size} \
    --gradient_accumulation_steps {grad_accum} \
    --learning_rate 5e-5 \
    --beta 0.1 \
    --use_4bit \
    --use_lora \
    --max_length 2048 \
    --max_prompt_length 1024 \
    --eval_steps 50 \
    --save_steps 100"""
            description = f"⭐ DPO PRODUCTION: 4399 pairs, 3 epochs, {model_size}"
            expected_time = "120-180 minutes" if model_size == "7B" else "60-90 minutes"

    # Build the banner once — the original constructed this exact f-string
    # twice (once for the first yield, once as the log seed).
    output = f"{description}\n⏱️ Expected time: {expected_time}\n\n{'='*60}\n\n"
    yield output

    # shell=True is required because `command` is a single string with
    # backslash continuations; every interpolated value comes from the fixed
    # Radio choices above, so no untrusted text reaches the shell.
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr so errors appear in the log
        text=True,
        shell=True,
    )

    # Stream child output line by line; each yield re-sends the whole log so
    # the textbox always shows the full history.
    for line in iter(process.stdout.readline, ''):
        output += line
        yield output
    process.stdout.close()

    return_code = process.wait()
    bar = "=" * 60
    if return_code == 0:
        yield output + "\n\n" + bar + "\n✅ Training finished successfully!\n" + bar
    else:
        yield output + "\n\n" + bar + f"\n❌ Training failed with return code {return_code}!\n" + bar
# Create Gradio interface: three radio selectors feed run_training, whose
# yielded strings stream into the log textbox (Gradio treats a generator fn
# as a streaming output).
demo = gr.Interface(
    fn=run_training,  # generator — each yield replaces the textbox content
    inputs=[
        # Order must match run_training's positional parameters.
        gr.Radio(
            choices=["dpo", "sft"],
            value="dpo",
            label="Training Method",
            info="⭐ DPO recommended: +26% F1, -63% hallucinations vs SFT"
        ),
        gr.Radio(
            choices=["test", "full"],
            value="test",
            label="Training Mode",
            info="Test: Quick validation. Full: Production training"
        ),
        gr.Radio(
            choices=["3B", "7B"],
            value="7B",
            label="Model Size",
            info="7B: Better accuracy (+5-10% F1), 2x slower. 3B: Faster, good baseline"
        )
    ],
    outputs=gr.Textbox(
        lines=30,
        label="Training Log",
        show_copy_button=True
    ),
    title="🚀 Qwen2.5 Fine-Tuning: SFT vs DPO (3B/7B)",
    # Markdown rendered above the inputs.
    description="""
Fine-tune Qwen2.5 (3B or 7B) for frequent itemset extraction using two methods:
### ⭐ DPO (Direct Preference Optimization) - Recommended
- **Dataset**: [itemset-extraction-rlhf-v1](https://huggingface.co/datasets/OliverSlivka/itemset-extraction-rlhf-v1) (auto-downloaded)
- **Data**: 4,399 preference pairs (chosen vs rejected responses)
- **Results (3B)**: F1=0.82, Hallucinations=3%, JSON Parse=98%
- **Results (7B)**: F1=0.87+ (estimated), Better reasoning
- **Test Mode**: 100 pairs, 1 epoch, ~15-30 min (3B/7B)
- **Full Mode**: 4,399 pairs, 3 epochs, ~60-180 min (3B/7B)
### SFT (Supervised Fine-Tuning) - Baseline
- **Dataset**: [itemset-extraction-v2](https://huggingface.co/datasets/OliverSlivka/itemset-extraction-v2)
- **Data**: 439 training examples
- **Results (3B)**: F1=0.65, Hallucinations=8%, JSON Parse=95%
- **Test Mode**: 50 examples, 1 epoch, ~10-25 min (3B/7B)
- **Full Mode**: 439 examples, 3 epochs, ~40-120 min (3B/7B)
### Model Comparison
- **3B**: Faster training, 8-10GB VRAM, good baseline
- **7B**: Better accuracy (+5-10% F1), 16-18GB VRAM, recommended for production
**Both use 4-bit quantization + LoRA to fit in GPU memory.**
⚠️ **GPU Requirements:**
- 3B: Zero GPU (A10G 16GB) ✅
- 7B: Zero GPU (may timeout) or Persistent GPU (L4 24GB) ✅ Recommended
""",
    # Markdown rendered below the interface.
    article="""
## Output Models
### DPO Models (⭐ Recommended)
- **3B Test**: `OliverSlivka/qwen2.5-3b-itemset-dpo-test`
- **3B Full**: `OliverSlivka/qwen2.5-3b-itemset-dpo`
- **7B Test**: `OliverSlivka/qwen2.5-7b-itemset-dpo-test`
- **7B Full**: `OliverSlivka/qwen2.5-7b-itemset-dpo` ⭐ Best
### SFT Models (Baseline)
- **3B Test**: `OliverSlivka/qwen2.5-3b-itemset-test`
- **3B Full**: `OliverSlivka/qwen2.5-3b-itemset-extractor`
- **7B Full**: `OliverSlivka/qwen2.5-7b-itemset-extractor`
## Why DPO > SFT?
| Metric | SFT (3B) | DPO (3B) | DPO (7B) |
|--------|----------|----------|----------|
| F1 Score | 0.65 | 0.82 | **0.87+** |
| Hallucinations | 8% | 3% | **<2%** |
| JSON Parse | 95% | 98% | **99%** |
| Exact Match | 0.45 | 0.55 | **0.65+** |
DPO learns from preference pairs (correct vs errors) while SFT only learns from correct answers.
7B model provides better reasoning and fewer edge case failures.
## Resources
- **Project**: [itemsety-qwen-finetuning](https://github.com/oliversl1vka/itemsety-qwen-finetuning)
- **DPO Paper**: [Direct Preference Optimization](https://arxiv.org/abs/2305.18290)
- **SFT Dataset**: [itemset-extraction-v2](https://huggingface.co/datasets/OliverSlivka/itemset-extraction-v2)
- **RLHF Dataset**: [itemset-extraction-rlhf-v1](https://huggingface.co/datasets/OliverSlivka/itemset-extraction-rlhf-v1)
## Training Tips
**For 7B models:**
- Use persistent GPU (L4 24GB) for full training (>2h)
- Zero GPU works for test mode
- Expect +5-10% better F1 vs 3B
- 2x slower but worth it for production
"""
)
if __name__ == "__main__":
    # Start the Gradio server when run as a script (the Space entry point).
    demo.launch()