Spaces:
Configuration error
Configuration error
Commit ·
d4398e6
0
Parent(s):
Initial release of Auto-FineTune-Ops
Browse files- .gitignore +51 -0
- .streamlit/config.toml +11 -0
- Auto_FineTune_Ops_Colab.ipynb +456 -0
- PROJECT_HIGHLIGHTS.md +54 -0
- README.md +152 -0
- agents/__init__.py +7 -0
- agents/data_architect.py +505 -0
- agents/the_judge.py +566 -0
- agents/training_pilot.py +528 -0
- app.py +1500 -0
- configs/default_config.yaml +160 -0
- main.py +482 -0
- preprocessing/__init__.py +8 -0
- preprocessing/augmentation.py +182 -0
- preprocessing/dataset_balancing.py +97 -0
- preprocessing/deduplication.py +84 -0
- preprocessing/output_formatter.py +150 -0
- preprocessing/pii_filter.py +165 -0
- preprocessing/pipeline.py +253 -0
- preprocessing/quality_filters.py +172 -0
- preprocessing/system_prompt.py +80 -0
- preprocessing/text_cleaning.py +124 -0
- preprocessing/tokenization.py +147 -0
- preprocessing/train_val_split.py +41 -0
- requirements.txt +42 -0
- scripts/__init__.py +5 -0
- scripts/deploy.py +375 -0
.gitignore
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual Environment
|
| 24 |
+
venv/
|
| 25 |
+
env/
|
| 26 |
+
.env
|
| 27 |
+
.venv/
|
| 28 |
+
|
| 29 |
+
# Environment Variables
|
| 30 |
+
.env
|
| 31 |
+
.env.local
|
| 32 |
+
.env.development.local
|
| 33 |
+
.env.test.local
|
| 34 |
+
.env.production.local
|
| 35 |
+
|
| 36 |
+
# IDE
|
| 37 |
+
.idea/
|
| 38 |
+
.vscode/
|
| 39 |
+
*.swp
|
| 40 |
+
*.swo
|
| 41 |
+
|
| 42 |
+
# Project Output
|
| 43 |
+
output/
|
| 44 |
+
logs/
|
| 45 |
+
reports/
|
| 46 |
+
models/
|
| 47 |
+
processed_data/
|
| 48 |
+
|
| 49 |
+
# OS
|
| 50 |
+
.DS_Store
|
| 51 |
+
Thumbs.db
|
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[theme]
|
| 2 |
+
primaryColor = "#6366f1"
|
| 3 |
+
backgroundColor = "#0f0f23"
|
| 4 |
+
secondaryBackgroundColor = "#1a1a2e"
|
| 5 |
+
textColor = "#e2e8f0"
|
| 6 |
+
font = "sans serif"
|
| 7 |
+
|
| 8 |
+
[server]
|
| 9 |
+
headless = true
|
| 10 |
+
enableCORS = false
|
| 11 |
+
enableXsrfProtection = true
|
Auto_FineTune_Ops_Colab.ipynb
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# 🤖 Auto-FineTune-Ops: One-Click Fine-Tuning Pipeline\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"**Run this notebook on Google Colab (with GPU) or Kaggle to fine-tune your LLM!**\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"This notebook combines all the agents:\n",
|
| 12 |
+
"- **DataArchitectAgent**: Cleans and formats your data\n",
|
| 13 |
+
"- **TrainingPilot**: Fine-tunes with Unsloth (ultra-fast LoRA)\n",
|
| 14 |
+
"- **TheJudge**: Evaluates base vs fine-tuned with LLM-as-Judge\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"---"
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "markdown",
|
| 21 |
+
"metadata": {},
|
| 22 |
+
"source": [
|
| 23 |
+
"## 1️⃣ Setup - Install Dependencies"
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"execution_count": null,
|
| 29 |
+
"metadata": {},
|
| 30 |
+
"outputs": [],
|
| 31 |
+
"source": [
|
| 32 |
+
"%%capture\n",
|
| 33 |
+
"# Install Unsloth (must be first!)\n",
|
| 34 |
+
"!pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 35 |
+
"!pip install --no-deps trl peft accelerate bitsandbytes\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"# Install other dependencies\n",
|
| 38 |
+
"!pip install datasets transformers rich pandas openai anthropic"
|
| 39 |
+
]
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"cell_type": "markdown",
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"source": [
|
| 45 |
+
"## 2️⃣ Configuration\n",
|
| 46 |
+
"\n",
|
| 47 |
+
"Set your training goal and upload your data!"
|
| 48 |
+
]
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"cell_type": "code",
|
| 52 |
+
"execution_count": null,
|
| 53 |
+
"metadata": {},
|
| 54 |
+
"outputs": [],
|
| 55 |
+
"source": [
|
| 56 |
+
"#@title ⚙️ Configuration\n",
|
| 57 |
+
"#@markdown ### Training Settings\n",
|
| 58 |
+
"GOAL = \"medical_assistant\" #@param {type:\"string\"}\n",
|
| 59 |
+
"BASE_MODEL = \"unsloth/llama-3-8b-bnb-4bit\" #@param [\"unsloth/llama-3-8b-bnb-4bit\", \"unsloth/mistral-7b-bnb-4bit\", \"unsloth/gemma-7b-bnb-4bit\"]\n",
|
| 60 |
+
"MAX_SEQ_LENGTH = 2048 #@param {type:\"integer\"}\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"#@markdown ### Evaluation Settings\n",
|
| 63 |
+
"RUN_EVALUATION = True #@param {type:\"boolean\"}\n",
|
| 64 |
+
"JUDGE_MODEL = \"gpt-4o\" #@param [\"gpt-4o\", \"claude-3-5-sonnet-20241022\"]\n",
|
| 65 |
+
"NUM_EVAL_SAMPLES = 20 #@param {type:\"integer\"}\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"#@markdown ### API Keys (for evaluation only)\n",
|
| 68 |
+
"OPENAI_API_KEY = \"\" #@param {type:\"string\"}\n",
|
| 69 |
+
"ANTHROPIC_API_KEY = \"\" #@param {type:\"string\"}\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"import os\n",
|
| 72 |
+
"if OPENAI_API_KEY:\n",
|
| 73 |
+
" os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n",
|
| 74 |
+
"if ANTHROPIC_API_KEY:\n",
|
| 75 |
+
" os.environ[\"ANTHROPIC_API_KEY\"] = ANTHROPIC_API_KEY\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"print(f\"✅ Goal: {GOAL}\")\n",
|
| 78 |
+
"print(f\"✅ Base Model: {BASE_MODEL}\")"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "markdown",
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"source": [
|
| 85 |
+
"## 3️⃣ Upload Your Data\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"Upload a CSV or JSON file with instruction-response pairs."
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": null,
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"outputs": [],
|
| 95 |
+
"source": [
|
| 96 |
+
"from google.colab import files\n",
|
| 97 |
+
"import pandas as pd\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"print(\"📂 Upload your dataset (CSV or JSON):\")\n",
|
| 100 |
+
"uploaded = files.upload()\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"# Get the uploaded file name\n",
|
| 103 |
+
"DATA_FILE = list(uploaded.keys())[0]\n",
|
| 104 |
+
"print(f\"\\n✅ Uploaded: {DATA_FILE}\")\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"# Preview the data\n",
|
| 107 |
+
"if DATA_FILE.endswith('.csv'):\n",
|
| 108 |
+
" df = pd.read_csv(DATA_FILE)\n",
|
| 109 |
+
"else:\n",
|
| 110 |
+
" df = pd.read_json(DATA_FILE, lines=DATA_FILE.endswith('.jsonl'))\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"print(f\"\\n📊 Dataset shape: {df.shape}\")\n",
|
| 113 |
+
"print(f\"📋 Columns: {list(df.columns)}\")\n",
|
| 114 |
+
"df.head(3)"
|
| 115 |
+
]
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"cell_type": "markdown",
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"source": [
|
| 121 |
+
"## 4️⃣ Stage 1: Data Preparation (DataArchitectAgent)"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": null,
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"outputs": [],
|
| 129 |
+
"source": [
|
| 130 |
+
"import json\n",
|
| 131 |
+
"import re\n",
|
| 132 |
+
"from dataclasses import dataclass, field\n",
|
| 133 |
+
"from typing import Optional, List, Dict, Tuple\n",
|
| 134 |
+
"import pandas as pd\n",
|
| 135 |
+
"from rich.console import Console\n",
|
| 136 |
+
"from rich.table import Table\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"console = Console()\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"@dataclass\n",
|
| 141 |
+
"class CleaningConfig:\n",
|
| 142 |
+
" min_instruction_length: int = 10\n",
|
| 143 |
+
" max_instruction_length: int = 2048\n",
|
| 144 |
+
" min_response_length: int = 20\n",
|
| 145 |
+
" max_response_length: int = 4096\n",
|
| 146 |
+
" remove_duplicates: bool = True\n",
|
| 147 |
+
"\n",
|
| 148 |
+
"class DataArchitectAgent:\n",
|
| 149 |
+
" \"\"\"Autonomous data preparation agent.\"\"\"\n",
|
| 150 |
+
" \n",
|
| 151 |
+
" INSTRUCTION_PATTERNS = [r'instruction', r'prompt', r'question', r'query', r'user', r'input_text']\n",
|
| 152 |
+
" OUTPUT_PATTERNS = [r'output', r'response', r'answer', r'completion', r'assistant', r'target']\n",
|
| 153 |
+
" \n",
|
| 154 |
+
" def __init__(self, config=None):\n",
|
| 155 |
+
" self.config = config or CleaningConfig()\n",
|
| 156 |
+
" \n",
|
| 157 |
+
" def _detect_columns(self, df):\n",
|
| 158 |
+
" instruction_col, output_col = None, None\n",
|
| 159 |
+
" for col in df.columns:\n",
|
| 160 |
+
" col_lower = col.lower()\n",
|
| 161 |
+
" for pattern in self.INSTRUCTION_PATTERNS:\n",
|
| 162 |
+
" if re.search(pattern, col_lower) and not instruction_col:\n",
|
| 163 |
+
" instruction_col = col\n",
|
| 164 |
+
" for pattern in self.OUTPUT_PATTERNS:\n",
|
| 165 |
+
" if re.search(pattern, col_lower) and not output_col:\n",
|
| 166 |
+
" output_col = col\n",
|
| 167 |
+
" return instruction_col, output_col\n",
|
| 168 |
+
" \n",
|
| 169 |
+
" def process(self, df, goal):\n",
|
| 170 |
+
" console.print(\"[bold blue]🏗️ DATA ARCHITECT AGENT[/]\")\n",
|
| 171 |
+
" \n",
|
| 172 |
+
" # Detect columns\n",
|
| 173 |
+
" inst_col, out_col = self._detect_columns(df)\n",
|
| 174 |
+
" console.print(f\"📌 Detected: instruction='{inst_col}', output='{out_col}'\")\n",
|
| 175 |
+
" \n",
|
| 176 |
+
" if not inst_col or not out_col:\n",
|
| 177 |
+
" raise ValueError(\"Could not auto-detect columns. Please rename to 'instruction' and 'output'.\")\n",
|
| 178 |
+
" \n",
|
| 179 |
+
" # Clean\n",
|
| 180 |
+
" df_clean = df.dropna(subset=[inst_col, out_col])\n",
|
| 181 |
+
" if self.config.remove_duplicates:\n",
|
| 182 |
+
" df_clean = df_clean.drop_duplicates(subset=[inst_col])\n",
|
| 183 |
+
" \n",
|
| 184 |
+
" # Length filters\n",
|
| 185 |
+
" df_clean = df_clean[\n",
|
| 186 |
+
" (df_clean[inst_col].str.len() >= self.config.min_instruction_length) &\n",
|
| 187 |
+
" (df_clean[inst_col].str.len() <= self.config.max_instruction_length) &\n",
|
| 188 |
+
" (df_clean[out_col].str.len() >= self.config.min_response_length) &\n",
|
| 189 |
+
" (df_clean[out_col].str.len() <= self.config.max_response_length)\n",
|
| 190 |
+
" ]\n",
|
| 191 |
+
" \n",
|
| 192 |
+
" console.print(f\"✅ Cleaned: {len(df_clean)} rows (from {len(df)})\")\n",
|
| 193 |
+
" \n",
|
| 194 |
+
" # Format for training\n",
|
| 195 |
+
" system_prompt = f\"You are a specialized AI assistant for {goal}.\"\n",
|
| 196 |
+
" \n",
|
| 197 |
+
" formatted = []\n",
|
| 198 |
+
" for _, row in df_clean.iterrows():\n",
|
| 199 |
+
" formatted.append({\n",
|
| 200 |
+
" \"instruction\": str(row[inst_col]),\n",
|
| 201 |
+
" \"input\": \"\",\n",
|
| 202 |
+
" \"output\": str(row[out_col]),\n",
|
| 203 |
+
" \"system\": system_prompt\n",
|
| 204 |
+
" })\n",
|
| 205 |
+
" \n",
|
| 206 |
+
" # Save\n",
|
| 207 |
+
" output_path = f\"/content/{goal}_training.jsonl\"\n",
|
| 208 |
+
" with open(output_path, 'w') as f:\n",
|
| 209 |
+
" for item in formatted:\n",
|
| 210 |
+
" f.write(json.dumps(item) + '\\n')\n",
|
| 211 |
+
" \n",
|
| 212 |
+
" console.print(f\"💾 Saved to: {output_path}\")\n",
|
| 213 |
+
" return output_path, len(formatted)\n",
|
| 214 |
+
"\n",
|
| 215 |
+
"# Run Data Agent\n",
|
| 216 |
+
"data_agent = DataArchitectAgent()\n",
|
| 217 |
+
"TRAINING_DATA_PATH, DATASET_SIZE = data_agent.process(df, GOAL)\n",
|
| 218 |
+
"print(f\"\\n✅ Dataset ready: {DATASET_SIZE} samples\")"
|
| 219 |
+
]
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"cell_type": "markdown",
|
| 223 |
+
"metadata": {},
|
| 224 |
+
"source": [
|
| 225 |
+
"## 5️⃣ Stage 2: Fine-Tuning (TrainingPilot)"
|
| 226 |
+
]
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"cell_type": "code",
|
| 230 |
+
"execution_count": null,
|
| 231 |
+
"metadata": {},
|
| 232 |
+
"outputs": [],
|
| 233 |
+
"source": [
|
| 234 |
+
"from unsloth import FastLanguageModel\n",
|
| 235 |
+
"from datasets import load_dataset\n",
|
| 236 |
+
"from trl import SFTTrainer\n",
|
| 237 |
+
"from transformers import TrainingArguments\n",
|
| 238 |
+
"import torch\n",
|
| 239 |
+
"\n",
|
| 240 |
+
"print(f\"🚀 GPU: {torch.cuda.get_device_name(0)}\")\n",
|
| 241 |
+
"print(f\"📊 VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
|
| 242 |
+
"\n",
|
| 243 |
+
"# Auto-configure hyperparameters based on dataset size\n",
|
| 244 |
+
"if DATASET_SIZE < 1000:\n",
|
| 245 |
+
" LORA_RANK, LORA_ALPHA, LR, EPOCHS = 8, 16, 2e-4, 5\n",
|
| 246 |
+
"elif DATASET_SIZE < 10000:\n",
|
| 247 |
+
" LORA_RANK, LORA_ALPHA, LR, EPOCHS = 16, 32, 1e-4, 3\n",
|
| 248 |
+
"else:\n",
|
| 249 |
+
" LORA_RANK, LORA_ALPHA, LR, EPOCHS = 32, 64, 5e-5, 2\n",
|
| 250 |
+
"\n",
|
| 251 |
+
"print(f\"\\n⚙️ Auto-configured for {DATASET_SIZE} samples:\")\n",
|
| 252 |
+
"print(f\" LoRA Rank: {LORA_RANK}, Alpha: {LORA_ALPHA}\")\n",
|
| 253 |
+
"print(f\" Learning Rate: {LR}, Epochs: {EPOCHS}\")"
|
| 254 |
+
]
|
| 255 |
+
},
|
| 256 |
+
{
|
| 257 |
+
"cell_type": "code",
|
| 258 |
+
"execution_count": null,
|
| 259 |
+
"metadata": {},
|
| 260 |
+
"outputs": [],
|
| 261 |
+
"source": [
|
| 262 |
+
"# Load model with Unsloth\n",
|
| 263 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 264 |
+
" model_name=BASE_MODEL,\n",
|
| 265 |
+
" max_seq_length=MAX_SEQ_LENGTH,\n",
|
| 266 |
+
" dtype=None,\n",
|
| 267 |
+
" load_in_4bit=True,\n",
|
| 268 |
+
")\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"# Apply LoRA\n",
|
| 271 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
| 272 |
+
" model,\n",
|
| 273 |
+
" r=LORA_RANK,\n",
|
| 274 |
+
" lora_alpha=LORA_ALPHA,\n",
|
| 275 |
+
" lora_dropout=0,\n",
|
| 276 |
+
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 277 |
+
" bias=\"none\",\n",
|
| 278 |
+
" use_gradient_checkpointing=\"unsloth\",\n",
|
| 279 |
+
" random_state=42,\n",
|
| 280 |
+
")\n",
|
| 281 |
+
"\n",
|
| 282 |
+
"print(\"✅ Model loaded with LoRA!\")"
|
| 283 |
+
]
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"cell_type": "code",
|
| 287 |
+
"execution_count": null,
|
| 288 |
+
"metadata": {},
|
| 289 |
+
"outputs": [],
|
| 290 |
+
"source": [
|
| 291 |
+
"# Load dataset\n",
|
| 292 |
+
"dataset = load_dataset('json', data_files=TRAINING_DATA_PATH, split='train')\n",
|
| 293 |
+
"\n",
|
| 294 |
+
"# Format prompts\n",
|
| 295 |
+
"alpaca_template = \"\"\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
|
| 296 |
+
"\n",
|
| 297 |
+
"### Instruction:\n",
|
| 298 |
+
"{instruction}\n",
|
| 299 |
+
"\n",
|
| 300 |
+
"### Response:\n",
|
| 301 |
+
"{output}\"\"\"\n",
|
| 302 |
+
"\n",
|
| 303 |
+
"def format_prompt(example):\n",
|
| 304 |
+
" return {\"text\": alpaca_template.format(**example)}\n",
|
| 305 |
+
"\n",
|
| 306 |
+
"dataset = dataset.map(format_prompt)\n",
|
| 307 |
+
"print(f\"✅ Loaded {len(dataset)} training samples\")"
|
| 308 |
+
]
|
| 309 |
+
},
|
| 310 |
+
{
|
| 311 |
+
"cell_type": "code",
|
| 312 |
+
"execution_count": null,
|
| 313 |
+
"metadata": {},
|
| 314 |
+
"outputs": [],
|
| 315 |
+
"source": [
|
| 316 |
+
"# Train!\n",
|
| 317 |
+
"trainer = SFTTrainer(\n",
|
| 318 |
+
" model=model,\n",
|
| 319 |
+
" tokenizer=tokenizer,\n",
|
| 320 |
+
" train_dataset=dataset,\n",
|
| 321 |
+
" dataset_text_field=\"text\",\n",
|
| 322 |
+
" max_seq_length=MAX_SEQ_LENGTH,\n",
|
| 323 |
+
" args=TrainingArguments(\n",
|
| 324 |
+
" output_dir=f\"/content/{GOAL}_model\",\n",
|
| 325 |
+
" num_train_epochs=EPOCHS,\n",
|
| 326 |
+
" per_device_train_batch_size=4,\n",
|
| 327 |
+
" gradient_accumulation_steps=4,\n",
|
| 328 |
+
" learning_rate=LR,\n",
|
| 329 |
+
" warmup_ratio=0.03,\n",
|
| 330 |
+
" fp16=True,\n",
|
| 331 |
+
" logging_steps=10,\n",
|
| 332 |
+
" save_strategy=\"epoch\",\n",
|
| 333 |
+
" optim=\"adamw_8bit\",\n",
|
| 334 |
+
" seed=42,\n",
|
| 335 |
+
" ),\n",
|
| 336 |
+
")\n",
|
| 337 |
+
"\n",
|
| 338 |
+
"print(\"🏋️ Training started...\")\n",
|
| 339 |
+
"trainer.train()\n",
|
| 340 |
+
"\n",
|
| 341 |
+
"# Save\n",
|
| 342 |
+
"MODEL_PATH = f\"/content/{GOAL}_model_final\"\n",
|
| 343 |
+
"trainer.save_model(MODEL_PATH)\n",
|
| 344 |
+
"tokenizer.save_pretrained(MODEL_PATH)\n",
|
| 345 |
+
"print(f\"\\n✅ Model saved to: {MODEL_PATH}\")"
|
| 346 |
+
]
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"cell_type": "markdown",
|
| 350 |
+
"metadata": {},
|
| 351 |
+
"source": [
|
| 352 |
+
"## 6️⃣ Stage 3: Evaluation (TheJudge) - Optional"
|
| 353 |
+
]
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"cell_type": "code",
|
| 357 |
+
"execution_count": null,
|
| 358 |
+
"metadata": {},
|
| 359 |
+
"outputs": [],
|
| 360 |
+
"source": [
|
| 361 |
+
"if RUN_EVALUATION and (OPENAI_API_KEY or ANTHROPIC_API_KEY):\n",
|
| 362 |
+
" print(\"⚖️ Running Model Arena evaluation...\")\n",
|
| 363 |
+
" \n",
|
| 364 |
+
" # Simple evaluation - compare responses\n",
|
| 365 |
+
" FastLanguageModel.for_inference(model)\n",
|
| 366 |
+
" \n",
|
| 367 |
+
" # Sample prompts from dataset\n",
|
| 368 |
+
" test_prompts = [dataset[i][\"instruction\"] for i in range(min(NUM_EVAL_SAMPLES, len(dataset)))]\n",
|
| 369 |
+
" \n",
|
| 370 |
+
" print(f\"\\n📊 Evaluating on {len(test_prompts)} samples...\")\n",
|
| 371 |
+
" print(\"Note: Full arena evaluation requires loading base model separately.\")\n",
|
| 372 |
+
" print(\"For complete evaluation, use the full TheJudge agent locally.\")\n",
|
| 373 |
+
" \n",
|
| 374 |
+
" # Quick test generation\n",
|
| 375 |
+
" test_prompt = test_prompts[0] if test_prompts else \"Hello, how are you?\"\n",
|
| 376 |
+
" inputs = tokenizer(f\"### Instruction:\\n{test_prompt}\\n\\n### Response:\\n\", return_tensors=\"pt\").to(\"cuda\")\n",
|
| 377 |
+
" outputs = model.generate(**inputs, max_new_tokens=128, temperature=0.7)\n",
|
| 378 |
+
" response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
|
| 379 |
+
" \n",
|
| 380 |
+
" print(f\"\\n📝 Sample generation:\")\n",
|
| 381 |
+
" print(f\"Prompt: {test_prompt[:100]}...\")\n",
|
| 382 |
+
" print(f\"Response: {response.split('### Response:')[-1][:200]}...\")\n",
|
| 383 |
+
"else:\n",
|
| 384 |
+
" print(\"⏭️ Skipping evaluation (no API key or disabled)\")"
|
| 385 |
+
]
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
"cell_type": "markdown",
|
| 389 |
+
"metadata": {},
|
| 390 |
+
"source": [
|
| 391 |
+
"## 7️⃣ Download Your Model"
|
| 392 |
+
]
|
| 393 |
+
},
|
| 394 |
+
{
|
| 395 |
+
"cell_type": "code",
|
| 396 |
+
"execution_count": null,
|
| 397 |
+
"metadata": {},
|
| 398 |
+
"outputs": [],
|
| 399 |
+
"source": [
|
| 400 |
+
"# Option 1: Save to Google Drive\n",
|
| 401 |
+
"from google.colab import drive\n",
|
| 402 |
+
"drive.mount('/content/drive')\n",
|
| 403 |
+
"\n",
|
| 404 |
+
"!cp -r {MODEL_PATH} /content/drive/MyDrive/{GOAL}_finetuned_model\n",
|
| 405 |
+
"print(f\"✅ Model copied to Google Drive: /MyDrive/{GOAL}_finetuned_model\")"
|
| 406 |
+
]
|
| 407 |
+
},
|
| 408 |
+
{
|
| 409 |
+
"cell_type": "code",
|
| 410 |
+
"execution_count": null,
|
| 411 |
+
"metadata": {},
|
| 412 |
+
"outputs": [],
|
| 413 |
+
"source": [
|
| 414 |
+
"# Option 2: Push to HuggingFace Hub\n",
|
| 415 |
+
"# Uncomment and fill in your details:\n",
|
| 416 |
+
"\n",
|
| 417 |
+
"# from huggingface_hub import login\n",
|
| 418 |
+
"# login(token=\"YOUR_HF_TOKEN\")\n",
|
| 419 |
+
"# \n",
|
| 420 |
+
"# model.push_to_hub(\"your-username/your-model-name\")\n",
|
| 421 |
+
"# tokenizer.push_to_hub(\"your-username/your-model-name\")\n",
|
| 422 |
+
"# print(\"✅ Pushed to HuggingFace Hub!\")"
|
| 423 |
+
]
|
| 424 |
+
},
|
| 425 |
+
{
|
| 426 |
+
"cell_type": "markdown",
|
| 427 |
+
"metadata": {},
|
| 428 |
+
"source": [
|
| 429 |
+
"## 🎉 Done!\n",
|
| 430 |
+
"\n",
|
| 431 |
+
"Your fine-tuned model is ready! You can:\n",
|
| 432 |
+
"1. Download from Google Drive\n",
|
| 433 |
+
"2. Push to HuggingFace Hub\n",
|
| 434 |
+
"3. Use the FastAPI deployment script locally"
|
| 435 |
+
]
|
| 436 |
+
}
|
| 437 |
+
],
|
| 438 |
+
"metadata": {
|
| 439 |
+
"accelerator": "GPU",
|
| 440 |
+
"colab": {
|
| 441 |
+
"gpuType": "T4",
|
| 442 |
+
"provenance": []
|
| 443 |
+
},
|
| 444 |
+
"kernelspec": {
|
| 445 |
+
"display_name": "venv",
|
| 446 |
+
"language": "python",
|
| 447 |
+
"name": "python3"
|
| 448 |
+
},
|
| 449 |
+
"language_info": {
|
| 450 |
+
"name": "python",
|
| 451 |
+
"version": "3.13.1"
|
| 452 |
+
}
|
| 453 |
+
},
|
| 454 |
+
"nbformat": 4,
|
| 455 |
+
"nbformat_minor": 0
|
| 456 |
+
}
|
PROJECT_HIGHLIGHTS.md
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚀 Auto-FineTune-Ops: Project Highlights
|
| 2 |
+
|
| 3 |
+
**Autonomous Machine Learning Pipeline for Production-Grade LLM Fine-Tuning**
|
| 4 |
+
|
| 5 |
+
Auto-FineTune-Ops is a comprehensive, no-code/low-code platform that democratizes access to state-of-the-art LLM fine-tuning. It automates the complex lifecycle of data preparation, training, evaluation, and deployment.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## 🌟 Key Features
|
| 10 |
+
|
| 11 |
+
### 1. 🧠 Intelligent Preprocessing Engine
|
| 12 |
+
A modular, production-ready data pipeline with 10+ specialized modules:
|
| 13 |
+
- **Text Cleaning:** Auto-strip HTML, emojis, URLs, and normalize whitespace.
|
| 14 |
+
- **PII Redaction:** Detect and mask emails, phone numbers, and API keys for security.
|
| 15 |
+
- **Deduplication:** Remove exact and semantic duplicates (using TF-IDF/Cosine Similarity).
|
| 16 |
+
- **Quality Filtering:** Filter by language, toxicity, and length constraints.
|
| 17 |
+
- **Advanced Formatting:** Auto-convert loose CSV/JSON into strict Chat Templates (ShareGPT/OpenAI).
|
| 18 |
+
|
| 19 |
+
### 2. ⚡ Hybrid Training Ecosystem
|
| 20 |
+
Flexible training workflows designed for all hardware setups:
|
| 21 |
+
- **Local GPU Power:** Leverages **Unsloth** for 2x faster training and 70% less memory usage (4-bit quantization).
|
| 22 |
+
- **Google Colab Bridge:** Seamless "No-GPU" fallback flow. Generate a ready-to-run Colab notebook to train on free cloud GPUs if local hardware is insufficient.
|
| 23 |
+
- **Custom Model Support:** Fine-tune any HuggingFace model (Llama 3, Mistral, Gemma, Phi-3, etc.).
|
| 24 |
+
|
| 25 |
+
### 3. ⚖️ Multi-Provider AI Judge Arena
|
| 26 |
+
Production-grade model evaluation using LLM-as-a-Judge:
|
| 27 |
+
- **Provider Agnostic:** Supports OpenAI (GPT-4o), Anthropic (Claude 3.5), Google (Gemini 1.5), and Groq (Llama 3).
|
| 28 |
+
- **Custom Endpoints:** Connect to local LLMs (Ollama/vLLM) as judges.
|
| 29 |
+
- **Comprehensive Metrics:** Automated scoring for Accuracy, Helpfulness, Clarity, and Tone.
|
| 30 |
+
- **Head-to-Head:** Win-rate visualization comparing Base Model vs. Fine-Tuned Model.
|
| 31 |
+
|
| 32 |
+
### 4. 🖥️ Interactive Streamlit Dashboard
|
| 33 |
+
A premium, dark-mode UI that abstracts away CLI complexity:
|
| 34 |
+
- **Project Management:** Manage datasets, models, and logs visually.
|
| 35 |
+
- **Real-time Monitoring:** Track training loss and progress live.
|
| 36 |
+
- **Visualization:** Interactive Plotly charts for evaluation results.
|
| 37 |
+
|
| 38 |
+
### 5. 🚀 One-Click Deployment
|
| 39 |
+
- **Instant API:** Export trained models as a production-ready **FastAPI** microservice.
|
| 40 |
+
- **Standardized Interface:** OpenAI-compatible `/generate` endpoints for easy integration into apps.
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## 🔧 Technical Stack
|
| 45 |
+
- **Frontend:** Streamlit, Plotly
|
| 46 |
+
- **Core ML:** PyTorch, Transformers, PEFT, Unsloth, TRL
|
| 47 |
+
- **Data:** Pandas, NumPy, Scikit-learn
|
| 48 |
+
- **API:** FastAPI, Uvicorn
|
| 49 |
+
- **LLM Clients:** OpenAI SDK, Anthropic SDK
|
| 50 |
+
|
| 51 |
+
## 🛡️ Production Readiness
|
| 52 |
+
- **Modular Architecture:** Agent-based design (DataArchitect, TrainingPilot, TheJudge) allows easy extensibility.
|
| 53 |
+
- **Error Handling:** Robust fallback mechanisms and detailed logging.
|
| 54 |
+
- **Security:** PII masking and API key management best practices.
|
README.md
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🤖 Auto-FineTune-Ops
|
| 2 |
+
|
| 3 |
+
> **Autonomous End-to-End LLM Fine-Tuning Pipeline**
|
| 4 |
+
>
|
| 5 |
+
> From raw data to production API in one click. No ML expertise required.
|
| 6 |
+
|
| 7 |
+
[Python](https://www.python.org/)
|
| 8 |
+
[Streamlit](https://streamlit.io/)
|
| 9 |
+
[License](LICENSE)
|
| 10 |
+
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
## 🎯 What Is This?
|
| 14 |
+
|
| 15 |
+
Auto-FineTune-Ops is a **no-code/low-code platform** that automates the entire lifecycle of fine-tuning Large Language Models (LLMs). It handles:
|
| 16 |
+
|
| 17 |
+
1. **Data Ingestion:** Upload CSV, JSON, or JSONL files.
|
| 18 |
+
2. **Advanced Preprocessing:** 10+ modules for cleaning, PII redaction, deduplication, and formatting.
|
| 19 |
+
3. **Hybrid Training:** Train locally on GPU (Unsloth/LoRA) or generate a **Google Colab Notebook** for free cloud GPU training.
|
| 20 |
+
4. **AI Judge Evaluation:** Compare your fine-tuned model against the base model using GPT-4, Claude 3.5, Gemini, or Groq as a judge.
|
| 21 |
+
5. **One-Click Deployment:** Export your trained model as a production-ready FastAPI endpoint.
|
| 22 |
+
|
| 23 |
+
**All accessible via a premium, easy-to-use Streamlit Dashboard.**
|
| 24 |
+
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
+
## ✨ Key Features
|
| 28 |
+
|
| 29 |
+
### 🧠 Intelligent Preprocessing
|
| 30 |
+
- **Text Cleaning:** Remove HTML, URLs, emojis, normalize whitespace.
|
| 31 |
+
- **PII Filter:** Redact emails, phone numbers, API keys.
|
| 32 |
+
- **Deduplication:** Remove exact and semantic (TF-IDF) duplicates.
|
| 33 |
+
- **Quality Filters:** Filter by length, language, toxicity.
|
| 34 |
+
- **Balancing:** Oversample/undersample classes for classification tasks.
|
| 35 |
+
- **Export Formats:** Auto-convert to OpenAI Chat, Completion, or Classification JSONL formats.
|
| 36 |
+
|
| 37 |
+
### ⚡ Flexible Training Workflows
|
| 38 |
+
- **Local GPU:** Uses **Unsloth** for ultra-fast 4-bit LoRA fine-tuning (2x faster, 70% less memory).
|
| 39 |
+
- **Google Colab Fallback:** Don't have a GPU? The app generates a ready-to-run Colab notebook for you. After training, upload the fine-tuned model back to the app for evaluation.
|
| 40 |
+
- **Custom Models:** Fine-tune any HuggingFace model (Llama 3, Mistral, Gemma, Phi-3, etc.).
|
| 41 |
+
|
| 42 |
+
### ⚖️ Multi-Provider AI Judge
|
| 43 |
+
Evaluate models head-to-head using:
|
| 44 |
+
- **OpenAI** (GPT-4o, GPT-4-turbo)
|
| 45 |
+
- **Anthropic** (Claude 3.5 Sonnet, Opus)
|
| 46 |
+
- **Google** (Gemini 1.5 Pro)
|
| 47 |
+
- **Groq** (Llama 3, Mixtral)
|
| 48 |
+
- **Custom Endpoints** (Ollama, vLLM)
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## 🚀 Quick Start
|
| 53 |
+
|
| 54 |
+
### 1. Installation
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
# Clone the repository
|
| 58 |
+
git clone https://github.com/your-username/Auto-FineTune-Ops.git
|
| 59 |
+
cd Auto-FineTune-Ops
|
| 60 |
+
|
| 61 |
+
# Create a virtual environment
|
| 62 |
+
python -m venv venv
|
| 63 |
+
# Windows:
|
| 64 |
+
.\venv\Scripts\activate
|
| 65 |
+
# Mac/Linux:
|
| 66 |
+
source venv/bin/activate
|
| 67 |
+
|
| 68 |
+
# Install dependencies
|
| 69 |
+
pip install -r requirements.txt
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### 2. Launch the Dashboard
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
streamlit run app.py
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
Open your browser to the URL shown (usually `http://localhost:8501`).
|
| 79 |
+
|
| 80 |
+
---
|
| 81 |
+
|
| 82 |
+
## 🛠️ Workflow Guide
|
| 83 |
+
|
| 84 |
+
### Step 1: Data Upload
|
| 85 |
+
- Upload your raw `CSV`, `JSON`, or `JSONL` file containing instruction-response pairs.
|
| 86 |
+
- The app automatically detects columns like `instruction`, `input`, `output`.
|
| 87 |
+
- Preview the full dataset with pagination.
|
| 88 |
+
|
| 89 |
+
### Step 2: Preprocessing
|
| 90 |
+
- Configure cleaning rules (HTML removal, lowercase, etc.).
|
| 91 |
+
- Set PII filters (mask emails/phones).
|
| 92 |
+
- Enable semantic deduplication.
|
| 93 |
+
- Click **Run Pipeline** to clean and format your data.
|
| 94 |
+
|
| 95 |
+
### Step 3: Training
|
| 96 |
+
- **If you have a GPU:** Select a base model (e.g., Llama-3-8b) and click **Start Training**.
|
| 97 |
+
- **If you have no GPU:**
|
| 98 |
+
1. Download the preprocessed data.
|
| 99 |
+
2. Download the generated `Colab Notebook`.
|
| 100 |
+
3. Run training on Google Colab (Free Tier).
|
| 101 |
+
4. Upload the fine-tuned model results back to the app.
|
| 102 |
+
|
| 103 |
+
### Step 4: Evaluation
|
| 104 |
+
- Compare your fine-tuned model vs. the base model.
|
| 105 |
+
- Select an AI Judge (e.g., GPT-4o).
|
| 106 |
+
- Visualize win rates and quality scores (Accuracy, Helpfulness, Tone).
|
| 107 |
+
|
| 108 |
+
### Step 5: Deployment
|
| 109 |
+
- Deploy your model locally as a REST API:
|
| 110 |
+
```bash
|
| 111 |
+
python scripts/deploy.py --model ./output/models/your_model --port 8000
|
| 112 |
+
```
|
| 113 |
+
- Or push to HuggingFace Hub directly from the dashboard.
|
| 114 |
+
|
| 115 |
+
---
|
| 116 |
+
|
| 117 |
+
## 🏗️ Project Structure
|
| 118 |
+
|
| 119 |
+
```
|
| 120 |
+
Auto-FineTune-Ops/
|
| 121 |
+
├── app.py # 🚀 Main Streamlit Dashboard
|
| 122 |
+
├── main.py # 🧠 CLI Orchestrator (Headless mode)
|
| 123 |
+
├── requirements.txt # Dependencies
|
| 124 |
+
├── agents/ # Core Logic Agents
|
| 125 |
+
│ ├── data_architect.py # Data Analysis & Cleaning
|
| 126 |
+
│ ├── training_pilot.py # Fine-Tuning Logic
|
| 127 |
+
│ └── the_judge.py # Evaluation Logic
|
| 128 |
+
├── preprocessing/ # Advanced Preprocessing Modules
|
| 129 |
+
│ ├── text_cleaning.py # Regex & Normalization
|
| 130 |
+
│ ├── pii_filter.py # PII Redaction
|
| 131 |
+
│ ├── deduplication.py # Semantic Dedupe
|
| 132 |
+
│ └── ...
|
| 133 |
+
├── configs/ # Configuration Files
|
| 134 |
+
└── output/ # Artifacts (Models, Logs, Reports)
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
---
|
| 138 |
+
|
| 139 |
+
## 🤝 Contributing
|
| 140 |
+
|
| 141 |
+
Contributions are welcome! Please read `CONTRIBUTING.md` for details on our code of conduct and the process for submitting pull requests.
|
| 142 |
+
|
| 143 |
+
## 📜 License
|
| 144 |
+
|
| 145 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 146 |
+
|
| 147 |
+
---
|
| 148 |
+
|
| 149 |
+
<div align="center">
|
| 150 |
+
<b>Built for modern ML teams.</b><br>
|
| 151 |
+
<i>Replace weeks of manual engineering with minutes of automated ops.</i>
|
| 152 |
+
</div>
|
agents/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Auto-FineTune-Ops Agents Package"""
|
| 2 |
+
|
| 3 |
+
from .data_architect import DataArchitectAgent
|
| 4 |
+
from .training_pilot import TrainingPilot
|
| 5 |
+
from .the_judge import TheJudge
|
| 6 |
+
|
| 7 |
+
__all__ = ["DataArchitectAgent", "TrainingPilot", "TheJudge"]
|
agents/data_architect.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DataArchitectAgent - Autonomous Data Preparation Agent
|
| 3 |
+
=======================================================
|
| 4 |
+
Takes raw CSV/JSON datasets and transforms them into high-quality
|
| 5 |
+
HuggingFace-ready JSONL format for fine-tuning.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import re
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from typing import Optional, List, Dict, Any, Tuple
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from rich.console import Console
|
| 15 |
+
from rich.progress import Progress, SpinnerColumn, TextColumn
|
| 16 |
+
from rich.table import Table
|
| 17 |
+
|
| 18 |
+
console = Console()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class DatasetAnalysis:
    """
    Structure and quality summary of one dataset.

    The four row counters and the column-type map are required; the
    column mapping, score and issue list are optional and default to
    "nothing detected".
    """
    # --- row counters -------------------------------------------------
    total_rows: int
    valid_rows: int
    invalid_rows: int
    duplicate_rows: int

    # --- column detection ---------------------------------------------
    detected_columns: Dict[str, str]  # column_name -> detected_type
    instruction_column: Optional[str] = None
    input_column: Optional[str] = None
    output_column: Optional[str] = None

    # --- quality ------------------------------------------------------
    quality_score: float = 0.0
    # A new list per instance, so analyses never share an issue list.
    issues: List[str] = field(default_factory=list)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class CleaningConfig:
    """
    Configuration for data cleaning.

    Lengths are measured in characters. ``quality_threshold`` is a
    fraction between 0 and 1.

    Raises:
        ValueError: If a length is negative, a minimum exceeds its
            maximum, or ``quality_threshold`` is outside ``[0, 1]``.
    """
    min_instruction_length: int = 10
    max_instruction_length: int = 2048
    min_response_length: int = 20
    max_response_length: int = 4096
    remove_duplicates: bool = True
    remove_empty: bool = True
    remove_special_chars: bool = False
    quality_threshold: float = 0.7

    def __post_init__(self):
        """Reject contradictory settings that would silently drop every row."""
        bounds = (
            ("instruction", self.min_instruction_length, self.max_instruction_length),
            ("response", self.min_response_length, self.max_response_length),
        )
        for label, lower, upper in bounds:
            if lower < 0 or upper < 0:
                raise ValueError(f"{label} length limits must be non-negative")
            if lower > upper:
                raise ValueError(
                    f"min_{label}_length ({lower}) exceeds max_{label}_length ({upper})"
                )
        if not 0.0 <= self.quality_threshold <= 1.0:
            raise ValueError(
                f"quality_threshold must be between 0 and 1, got {self.quality_threshold}"
            )
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DataArchitectAgent:
    """
    Autonomous agent for data preparation and cleaning.

    This agent analyzes raw datasets, identifies instruction-response pairs,
    cleans the data, and formats it for HuggingFace fine-tuning.

    ``process`` runs the whole pipeline: ``load_dataset`` ->
    ``analyze_dataset`` -> ``clean_data`` -> ``format_for_training``.
    """

    # Common column name patterns for auto-detection. Each entry is a regex
    # searched (unanchored) in the lower-cased column name; the three lists
    # are tried in the order instruction -> input -> output.
    INSTRUCTION_PATTERNS = [
        r'instruction', r'prompt', r'question', r'query', r'input_text',
        r'human', r'user', r'request', r'ask', r'command'
    ]
    INPUT_PATTERNS = [
        r'context', r'input', r'background', r'reference', r'document'
    ]
    OUTPUT_PATTERNS = [
        r'output', r'response', r'answer', r'completion', r'reply',
        r'assistant', r'bot', r'generated', r'target'
    ]

    def __init__(self, config: "Optional[CleaningConfig]" = None):
        """
        Initialize the DataArchitectAgent.

        Args:
            config: Cleaning settings; a default ``CleaningConfig()`` is
                used when omitted.
        """
        self.config = config or CleaningConfig()
        # Filled in by analyze_dataset(); clean_data() and
        # format_for_training() fall back to its column mapping.
        self.analysis: Optional[DatasetAnalysis] = None

    def load_dataset(self, path: str) -> pd.DataFrame:
        """
        Load a dataset from CSV or JSON file.

        Args:
            path: Path to the dataset file (``.csv``, ``.json`` or ``.jsonl``)

        Returns:
            Loaded DataFrame

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If the file extension is not supported.
        """
        path = Path(path)

        if not path.exists():
            raise FileNotFoundError(f"Dataset not found: {path}")

        console.print(f"[bold blue]📂 Loading dataset:[/] {path}")

        suffix = path.suffix.lower()
        if suffix == '.csv':
            df = pd.read_csv(path)
        elif suffix == '.jsonl':
            # JSON Lines: one JSON object per line.
            df = pd.read_json(path, lines=True)
        elif suffix == '.json':
            df = pd.read_json(path)
        else:
            raise ValueError(f"Unsupported file format: {path.suffix}")

        console.print(f"[green]✓ Loaded {len(df)} rows with {len(df.columns)} columns[/]")
        return df

    @staticmethod
    def _as_text(series: pd.Series) -> pd.Series:
        """Return ``series`` as stripped strings, with missing values as ''."""
        # Go through object dtype so the '' fill never clashes with a
        # numeric dtype, then stringify whatever is left.
        text = series.astype(object)
        text = text.where(text.notna(), '')
        return text.astype(str).str.strip()

    @staticmethod
    def _cell_text(value: Any) -> str:
        """Return one cell as text, mapping a missing scalar to ''."""
        # is_scalar guards pd.isna against list-like cells, for which it
        # would return an array instead of a single bool.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ''
        return str(value)

    def _valid_mask(self, df: pd.DataFrame, instruction_col: str, output_col: str) -> pd.Series:
        """Return a boolean mask of rows whose stripped text meets the length limits."""
        cfg = self.config
        instruction_len = self._as_text(df[instruction_col]).str.len()
        output_len = self._as_text(df[output_col]).str.len()
        # between() is inclusive on both ends, matching the >= / <= limits.
        return (
            instruction_len.between(cfg.min_instruction_length, cfg.max_instruction_length)
            & output_len.between(cfg.min_response_length, cfg.max_response_length)
        )

    def _match_column_pattern(self, column: str, patterns: List[str]) -> bool:
        """Check if a column name matches any of the given patterns."""
        # str() so that non-string labels (e.g. integer columns) do not crash.
        column_lower = str(column).lower()
        return any(re.search(pattern, column_lower) for pattern in patterns)

    def _detect_column_type(self, column: str) -> str:
        """
        Detect the type of a column based on its name.

        Returns:
            'instruction', 'input', 'output' or 'unknown'. The first
            matching group wins, so e.g. 'input_text' is an instruction.
        """
        if self._match_column_pattern(column, self.INSTRUCTION_PATTERNS):
            return 'instruction'
        if self._match_column_pattern(column, self.INPUT_PATTERNS):
            return 'input'
        if self._match_column_pattern(column, self.OUTPUT_PATTERNS):
            return 'output'
        return 'unknown'

    def analyze_dataset(
        self,
        df: pd.DataFrame,
        instruction_col: Optional[str] = None,
        input_col: Optional[str] = None,
        output_col: Optional[str] = None
    ) -> "DatasetAnalysis":
        """
        Analyze a dataset to understand its structure and quality.

        Args:
            df: Input DataFrame
            instruction_col: Override for the auto-detected instruction column
            input_col: Override for the auto-detected input/context column
            output_col: Override for the auto-detected output column

        Returns:
            DatasetAnalysis with detected columns and quality metrics.
            The result is also stored on ``self.analysis``.
        """
        console.print("\n[bold blue]🔍 Analyzing dataset structure...[/]")

        # Detect column types; the first column of each type wins.
        detected_columns = {}
        first_of_type = {'instruction': None, 'input': None, 'output': None}
        for col in df.columns:
            col_type = self._detect_column_type(col)
            detected_columns[col] = col_type
            if col_type in first_of_type and first_of_type[col_type] is None:
                first_of_type[col_type] = col

        # Explicit overrides take precedence, so the metrics below describe
        # the columns that will actually be used.
        instruction_col = instruction_col or first_of_type['instruction']
        input_col = input_col or first_of_type['input']
        output_col = output_col or first_of_type['output']

        # Check for required columns
        issues = []
        if instruction_col is None:
            issues.append("❌ No instruction/prompt column detected")
        elif instruction_col not in df.columns:
            issues.append(f"❌ Instruction column not found in dataset: {instruction_col}")
        if output_col is None:
            issues.append("❌ No output/response column detected")
        elif output_col not in df.columns:
            issues.append(f"❌ Output column not found in dataset: {output_col}")

        has_instruction = instruction_col is not None and instruction_col in df.columns
        has_output = output_col is not None and output_col in df.columns

        # A row is valid only when both required columns exist and both
        # texts fall inside the configured length limits.
        total_rows = len(df)
        if has_instruction and has_output:
            valid_rows = int(self._valid_mask(df, instruction_col, output_col).sum())
        else:
            valid_rows = 0
        invalid_rows = total_rows - valid_rows

        # Count duplicates (plain int, not a numpy integer)
        duplicate_rows = 0
        if has_instruction:
            duplicate_rows = int(self._as_text(df[instruction_col]).duplicated().sum())

        # Calculate quality score
        quality_score = valid_rows / total_rows if total_rows > 0 else 0.0

        self.analysis = DatasetAnalysis(
            total_rows=total_rows,
            valid_rows=valid_rows,
            invalid_rows=invalid_rows,
            duplicate_rows=duplicate_rows,
            detected_columns=detected_columns,
            instruction_column=instruction_col,
            input_column=input_col,
            output_column=output_col,
            quality_score=quality_score,
            issues=issues
        )

        # Display analysis results
        self._display_analysis()

        return self.analysis

    def _display_analysis(self):
        """Display the analysis results in a formatted table."""
        if not self.analysis:
            return

        table = Table(title="Dataset Analysis", show_header=True)
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green")

        table.add_row("Total Rows", str(self.analysis.total_rows))
        table.add_row("Valid Rows", str(self.analysis.valid_rows))
        table.add_row("Invalid Rows", str(self.analysis.invalid_rows))
        table.add_row("Duplicate Rows", str(self.analysis.duplicate_rows))
        table.add_row("Quality Score", f"{self.analysis.quality_score:.2%}")

        console.print(table)

        # Show detected columns
        console.print("\n[bold]Detected Column Mappings:[/]")
        console.print(f" • Instruction: [cyan]{self.analysis.instruction_column or 'Not detected'}[/]")
        console.print(f" • Input/Context: [cyan]{self.analysis.input_column or 'Not detected'}[/]")
        console.print(f" • Output/Response: [cyan]{self.analysis.output_column or 'Not detected'}[/]")

        if self.analysis.issues:
            console.print("\n[bold red]Issues Found:[/]")
            for issue in self.analysis.issues:
                console.print(f" {issue}")

    def clean_data(
        self,
        df: pd.DataFrame,
        instruction_col: Optional[str] = None,
        input_col: Optional[str] = None,
        output_col: Optional[str] = None
    ) -> pd.DataFrame:
        """
        Clean and validate the dataset.

        Text is normalized first (missing -> '', stringified, stripped) so
        that the empty/duplicate/length steps all see the same values.

        Args:
            df: Input DataFrame (not modified; a copy is cleaned)
            instruction_col: Override instruction column name
            input_col: Override input column name
            output_col: Override output column name

        Returns:
            Cleaned DataFrame

        Raises:
            ValueError: If no instruction or output column is known.
        """
        console.print("\n[bold blue]🧹 Cleaning dataset...[/]")

        # Use detected columns if not specified
        if self.analysis:
            instruction_col = instruction_col or self.analysis.instruction_column
            input_col = input_col or self.analysis.input_column
            output_col = output_col or self.analysis.output_column

        if not instruction_col or not output_col:
            raise ValueError("Instruction and output columns are required")

        df_clean = df.copy()
        original_count = len(df_clean)

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            # Clean text
            task = progress.add_task("Cleaning text...", total=None)
            df_clean[instruction_col] = self._as_text(df_clean[instruction_col])
            df_clean[output_col] = self._as_text(df_clean[output_col])
            if input_col and input_col in df_clean.columns:
                df_clean[input_col] = self._as_text(df_clean[input_col])
            progress.update(task, completed=1)

            # Remove empty values (missing or whitespace-only text)
            if self.config.remove_empty:
                task = progress.add_task("Removing empty values...", total=None)
                has_text = (df_clean[instruction_col] != '') & (df_clean[output_col] != '')
                df_clean = df_clean[has_text]
                progress.update(task, completed=1)

            # Remove duplicates
            if self.config.remove_duplicates:
                task = progress.add_task("Removing duplicates...", total=None)
                df_clean = df_clean.drop_duplicates(subset=[instruction_col])
                progress.update(task, completed=1)

            # Filter by length constraints
            task = progress.add_task("Applying length filters...", total=None)
            df_clean = df_clean[self._valid_mask(df_clean, instruction_col, output_col)]
            progress.update(task, completed=1)

        removed_count = original_count - len(df_clean)
        console.print(f"[green]✓ Cleaned dataset: {len(df_clean)} rows remaining ({removed_count} removed)[/]")

        return df_clean

    def format_for_training(
        self,
        df: pd.DataFrame,
        goal: str,
        output_path: str,
        instruction_col: Optional[str] = None,
        input_col: Optional[str] = None,
        output_col: Optional[str] = None
    ) -> str:
        """
        Format the dataset into HuggingFace-ready JSONL.

        Each line holds the Alpaca-style keys (``instruction``, ``input``,
        ``output``, ``system``) plus a ``conversations`` list in chat format.

        Args:
            df: Cleaned DataFrame
            goal: Training goal/purpose (e.g., 'medical_assistant')
            output_path: Path to save the JSONL file
            instruction_col: Instruction column name
            input_col: Input/context column name
            output_col: Output/response column name

        Returns:
            Path to the created JSONL file

        Raises:
            ValueError: If no instruction or output column is known.
        """
        console.print(f"\n[bold blue]📝 Formatting for training goal: [cyan]{goal}[/][/]")

        # Use detected columns if not specified
        if self.analysis:
            instruction_col = instruction_col or self.analysis.instruction_column
            input_col = input_col or self.analysis.input_column
            output_col = output_col or self.analysis.output_column

        if not instruction_col or not output_col:
            raise ValueError("Instruction and output columns are required")

        # Resolve the columns before touching the filesystem, so a wrong
        # column name fails without creating or truncating the output file.
        instructions = df[instruction_col]
        outputs = df[output_col]
        if input_col and input_col in df.columns:
            contexts = df[input_col]
        else:
            contexts = [None] * len(df)

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Create system prompt based on goal
        system_prompt = self._generate_system_prompt(goal)

        sample_count = 0

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Formatting entries...", total=len(df))

            # Write JSONL line by line instead of buffering every entry.
            with open(output_path, 'w', encoding='utf-8') as f:
                for raw_instruction, raw_output, raw_context in zip(instructions, outputs, contexts):
                    instruction = self._cell_text(raw_instruction)
                    output = self._cell_text(raw_output)
                    # Missing context becomes '' rather than the text "nan".
                    context = self._cell_text(raw_context)

                    # Format as Alpaca-style instruction format
                    entry = {
                        "instruction": instruction,
                        "input": context,
                        "output": output,
                        "system": system_prompt
                    }

                    # Also create chat format for compatibility
                    entry["conversations"] = [
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": instruction + (f"\n\nContext: {context}" if context else "")},
                        {"role": "assistant", "content": output}
                    ]

                    f.write(json.dumps(entry, ensure_ascii=False) + '\n')
                    sample_count += 1
                    progress.advance(task)

        console.print(f"[green]✓ Created training file: {output_path}[/]")
        console.print(f" • Total samples: {sample_count}")
        console.print(" • Format: JSONL (Alpaca-style + Chat format)")

        return str(output_path)

    def _generate_system_prompt(self, goal: str) -> str:
        """
        Generate a system prompt based on the training goal.

        The goal is lower-cased with '_' and '-' turned into spaces; the
        first template key contained in it wins, otherwise a generic
        prompt naming the goal is returned.
        """
        goal_lower = goal.lower().replace('_', ' ').replace('-', ' ')

        # Common goal templates (checked in this order)
        templates = {
            'medical': "You are a knowledgeable medical assistant. Provide accurate, helpful medical information while always recommending users consult healthcare professionals for specific medical advice.",
            'legal': "You are a legal information assistant. Provide helpful legal information while noting that you are not a lawyer and users should consult legal professionals for specific legal advice.",
            'coding': "You are an expert programming assistant. Help users write clean, efficient, and well-documented code. Explain your solutions clearly.",
            'customer': "You are a helpful customer service assistant. Be polite, professional, and focused on solving customer issues efficiently.",
            'education': "You are an educational assistant. Explain concepts clearly and adapt your explanations to the user's level of understanding.",
            'writing': "You are a skilled writing assistant. Help users improve their writing with clear, constructive feedback and suggestions.",
            'assistant': "You are a helpful AI assistant. Provide accurate, useful responses while being conversational and engaging."
        }

        # Find matching template
        for key, prompt in templates.items():
            if key in goal_lower:
                return prompt

        # Default template
        return f"You are a specialized AI assistant for {goal}. Provide helpful, accurate, and relevant responses to user queries."

    def process(
        self,
        input_path: str,
        output_path: str,
        goal: str,
        instruction_col: Optional[str] = None,
        input_col: Optional[str] = None,
        output_col: Optional[str] = None
    ) -> "Tuple[str, DatasetAnalysis]":
        """
        Complete end-to-end processing pipeline.

        Args:
            input_path: Path to input dataset
            output_path: Path for output JSONL
            goal: Training goal
            instruction_col: Override instruction column
            input_col: Override input column
            output_col: Override output column

        Returns:
            Tuple of (output_path, analysis)
        """
        console.print("\n" + "="*60)
        console.print("[bold magenta]🏗️ DATA ARCHITECT AGENT[/]")
        console.print("="*60)

        # Load
        df = self.load_dataset(input_path)

        # Analyze (with the caller's overrides, so the quality score is
        # measured on the columns that are actually cleaned and exported)
        analysis = self.analyze_dataset(
            df,
            instruction_col=instruction_col,
            input_col=input_col,
            output_col=output_col
        )

        # Check quality (warning only; processing continues)
        if analysis.quality_score < self.config.quality_threshold:
            console.print(f"[yellow]⚠️ Warning: Quality score ({analysis.quality_score:.2%}) below threshold ({self.config.quality_threshold:.2%})[/]")

        # Clean
        df_clean = self.clean_data(
            df,
            instruction_col=instruction_col or analysis.instruction_column,
            input_col=input_col or analysis.input_column,
            output_col=output_col or analysis.output_column
        )

        # Format
        final_path = self.format_for_training(
            df_clean,
            goal=goal,
            output_path=output_path,
            instruction_col=instruction_col or analysis.instruction_column,
            input_col=input_col or analysis.input_column,
            output_col=output_col or analysis.output_column
        )

        console.print("\n[bold green]✅ Data preparation complete![/]")

        return final_path, analysis
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
if __name__ == "__main__":
    # Minimal command-line entry point: <input_file> <goal>
    import sys

    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        print("Usage: python data_architect.py <input_file> <goal>")
        sys.exit(1)

    input_file, goal = cli_args[0], cli_args[1]
    # The goal doubles as the stem of the generated training file.
    output_file = f"./output/processed_data/{goal}_training.jsonl"

    agent = DataArchitectAgent()
    agent.process(input_file, output_file, goal)
|
agents/the_judge.py
ADDED
|
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TheJudge - LLM-as-a-Judge Evaluation Agent
|
| 3 |
+
=============================================
|
| 4 |
+
Runs a 'Model Arena' comparing base vs fine-tuned models using
|
| 5 |
+
a Judge LLM (GPT-4o or Claude 3.5) to score responses.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
import random
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
from typing import Optional, List, Dict, Any, Tuple
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
from enum import Enum
|
| 16 |
+
|
| 17 |
+
from rich.console import Console
|
| 18 |
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
|
| 19 |
+
from rich.table import Table
|
| 20 |
+
from rich.panel import Panel
|
| 21 |
+
from rich.markdown import Markdown
|
| 22 |
+
|
| 23 |
+
console = Console()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class JudgeModel(Enum):
    """LLMs that can act as the judge; each value is the API model name."""
    GPT4O = "gpt-4o"
    CLAUDE_35_SONNET = "claude-3-5-sonnet-20241022"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class Verdict:
    """Outcome of judging one prompt: both responses, scores and the winner."""
    prompt: str

    # The two candidate answers being compared.
    response_a: str  # Base model
    response_b: str  # Fine-tuned model

    winner: str  # 'A', 'B', or 'TIE'
    score_a: int  # 1-10
    score_b: int  # 1-10
    reasoning: str

    # Per-criterion scores; a fresh dict per verdict so none are shared.
    criteria_scores: Dict[str, Dict[str, int]] = field(default_factory=dict)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
class ArenaResult:
    """Aggregate outcome of one arena run; every field is required."""
    # One entry per judged comparison.
    verdicts: List[Verdict]

    # Win/tie tallies.
    base_model_wins: int
    finetuned_wins: int
    ties: int

    # Score averages and headline win rate.
    base_model_avg_score: float
    finetuned_avg_score: float
    win_rate: float

    # Run metadata.
    total_comparisons: int
    evaluation_time: float
    judge_model: str
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class TheJudge:
|
| 61 |
+
"""
|
| 62 |
+
Model Arena evaluation agent using LLM-as-a-Judge.
|
| 63 |
+
|
| 64 |
+
Compares base model vs fine-tuned model responses and
|
| 65 |
+
provides detailed scoring based on multiple criteria.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
EVALUATION_CRITERIA = [
|
| 69 |
+
("helpfulness", "How helpful and useful is the response?"),
|
| 70 |
+
("accuracy", "How accurate and factually correct is the response?"),
|
| 71 |
+
("relevance", "How relevant is the response to the user's query?"),
|
| 72 |
+
("clarity", "How clear and well-structured is the response?"),
|
| 73 |
+
("completeness", "How complete and thorough is the response?")
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
JUDGE_PROMPT = """You are an expert evaluator comparing two AI assistant responses. Your task is to evaluate which response is better based on multiple criteria.
|
| 77 |
+
|
| 78 |
+
## User Query
|
| 79 |
+
{prompt}
|
| 80 |
+
|
| 81 |
+
## Response A
|
| 82 |
+
{response_a}
|
| 83 |
+
|
| 84 |
+
## Response B
|
| 85 |
+
{response_b}
|
| 86 |
+
|
| 87 |
+
## Evaluation Criteria
|
| 88 |
+
For each criterion, rate both responses on a scale of 1-10:
|
| 89 |
+
1. Helpfulness: How helpful and useful is the response?
|
| 90 |
+
2. Accuracy: How accurate and factually correct is the response?
|
| 91 |
+
3. Relevance: How relevant is the response to the user's query?
|
| 92 |
+
4. Clarity: How clear and well-structured is the response?
|
| 93 |
+
5. Completeness: How complete and thorough is the response?
|
| 94 |
+
|
| 95 |
+
## Instructions
|
| 96 |
+
1. Evaluate both responses fairly and objectively
|
| 97 |
+
2. Consider the strengths and weaknesses of each response
|
| 98 |
+
3. Provide specific reasoning for your evaluation
|
| 99 |
+
4. Determine the overall winner (A, B, or TIE)
|
| 100 |
+
|
| 101 |
+
## Output Format (JSON)
|
| 102 |
+
{{
|
| 103 |
+
"helpfulness": {{"A": <1-10>, "B": <1-10>}},
|
| 104 |
+
"accuracy": {{"A": <1-10>, "B": <1-10>}},
|
| 105 |
+
"relevance": {{"A": <1-10>, "B": <1-10>}},
|
| 106 |
+
"clarity": {{"A": <1-10>, "B": <1-10>}},
|
| 107 |
+
"completeness": {{"A": <1-10>, "B": <1-10>}},
|
| 108 |
+
"overall_score_a": <1-10>,
|
| 109 |
+
"overall_score_b": <1-10>,
|
| 110 |
+
"winner": "<A|B|TIE>",
|
| 111 |
+
"reasoning": "<detailed explanation of your evaluation>"
|
| 112 |
+
}}
|
| 113 |
+
|
| 114 |
+
Respond with ONLY the JSON object, no additional text."""
|
| 115 |
+
|
| 116 |
+
def __init__(
    self,
    judge_model: JudgeModel = JudgeModel.GPT4O,
    temperature: float = 0.2,
    max_tokens: int = 1024
):
    """
    Configure the judge; no API connection is opened here.

    Args:
        judge_model: LLM backend that scores the two responses
        temperature: Sampling temperature forwarded to the judge
        max_tokens: Upper bound on the length of the judge's reply
    """
    # The API client is created lazily by _get_client() on first use.
    self._client = None
    self.max_tokens = max_tokens
    self.temperature = temperature
    self.judge_model = judge_model
|
| 134 |
+
|
| 135 |
+
def _get_client(self):
    """
    Return the cached API client, creating it on first use.

    The OpenAI client is built when the judge model is GPT4O; every other
    judge model is served by the Anthropic client.

    Returns:
        The OpenAI or Anthropic client instance

    Raises:
        ImportError: If the required SDK package is not installed
        ValueError: If the matching API key environment variable is not set
    """
    if self._client is not None:
        return self._client

    if self.judge_model == JudgeModel.GPT4O:
        # Only the import sits in the try body, so nothing but a missing
        # package is reported as ImportError; the cause is chained.
        try:
            from openai import OpenAI
        except ImportError as exc:
            raise ImportError("OpenAI package required. Install with: pip install openai") from exc

        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set")
        self._client = OpenAI(api_key=api_key)
    else:
        try:
            from anthropic import Anthropic
        except ImportError as exc:
            raise ImportError("Anthropic package required. Install with: pip install anthropic") from exc

        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY environment variable not set")
        self._client = Anthropic(api_key=api_key)

    return self._client
|
| 160 |
+
|
| 161 |
+
def _call_judge(self, prompt: str) -> str:
    """
    Send the evaluation prompt to the judge LLM and return its reply text.

    Args:
        prompt: Fully formatted judge prompt

    Returns:
        Raw text of the judge's reply, or an empty string when the API
        returned no text (the verdict parser then falls back to a tie
        instead of crashing on a non-string value)
    """
    client = self._get_client()
    messages = [{"role": "user", "content": prompt}]

    if self.judge_model == JudgeModel.GPT4O:
        response = client.chat.completions.create(
            model=self.judge_model.value,
            messages=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens
        )
        # The chat completion message content is nullable in the OpenAI API.
        return response.choices[0].message.content or ""

    response = client.messages.create(
        model=self.judge_model.value,
        max_tokens=self.max_tokens,
        temperature=self.temperature,
        messages=messages
    )
    # Guard against an empty content list before indexing the first block.
    blocks = response.content
    return blocks[0].text if blocks else ""
|
| 181 |
+
|
| 182 |
+
def _parse_verdict(self, response: str, prompt: str, resp_a: str, resp_b: str) -> Verdict:
    """
    Parse the judge's raw reply into a Verdict.

    Any reply that cannot be interpreted (invalid JSON, a JSON value that
    is not an object, non-numeric overall scores) yields a neutral TIE
    verdict with scores of 5 and the start of the raw reply as reasoning.

    Args:
        response: Raw text returned by the judge
        prompt: Original user prompt
        resp_a: Response shown to the judge as "A"
        resp_b: Response shown to the judge as "B"

    Returns:
        Verdict populated from the reply, or the neutral fallback
    """

    def _as_score(value):
        # Keep genuine numbers as-is; coerce numeric strings such as "8".
        if isinstance(value, (int, float)):
            return value
        return float(value)

    try:
        # Clean response (remove markdown code blocks if present)
        clean_response = (response or "").strip()
        if clean_response.startswith("```"):
            clean_response = clean_response.split("```")[1]
            if clean_response.startswith("json"):
                clean_response = clean_response[4:]

        data = json.loads(clean_response)
        if not isinstance(data, dict):
            raise ValueError("judge reply is not a JSON object")

        criteria_scores = {
            criterion: data[criterion]
            for criterion, _ in self.EVALUATION_CRITERIA
            if criterion in data
        }

        # Normalise labels like "a" or "Tie"; anything else counts as a tie
        # so that wins + ties always add up to the number of comparisons.
        winner = str(data.get("winner", "TIE")).strip().upper()
        if winner not in ("A", "B", "TIE"):
            winner = "TIE"

        return Verdict(
            prompt=prompt,
            response_a=resp_a,
            response_b=resp_b,
            winner=winner,
            score_a=_as_score(data.get("overall_score_a", 5)),
            score_b=_as_score(data.get("overall_score_b", 5)),
            reasoning=data.get("reasoning", "No reasoning provided"),
            criteria_scores=criteria_scores
        )
    except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
        console.print(f"[yellow]⚠️  Warning: Failed to parse judge response: {e}[/]")
        return Verdict(
            prompt=prompt,
            response_a=resp_a,
            response_b=resp_b,
            winner="TIE",
            score_a=5,
            score_b=5,
            reasoning=f"Parse error: {str(response)[:200]}...",
            criteria_scores={}
        )
|
| 221 |
+
|
| 222 |
+
def generate_response(
    self,
    model: Any,
    tokenizer: Any,
    prompt: str,
    max_new_tokens: int = 512
) -> str:
    """
    Generate a response from a model using the Alpaca prompt template.

    Args:
        model: The language model
        tokenizer: The tokenizer
        prompt: Input prompt
        max_new_tokens: Maximum tokens to generate

    Returns:
        Generated response string (prompt excluded, whitespace stripped)
    """
    try:
        from unsloth import FastLanguageModel
        FastLanguageModel.for_inference(model)
    except ImportError:
        # Without Unsloth, fall through and generate with the model as-is.
        pass

    # Alpaca template (no-input variant), spelled out with explicit newlines.
    alpaca_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n"
        f"{prompt}\n\n"
        "### Response:\n"
    )

    inputs = tokenizer(alpaca_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )

    # Decode only the newly generated tokens. Splitting the full text on the
    # last "### Response:" marker would cut the answer short whenever the
    # model itself emits that marker.
    prompt_length = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
| 273 |
+
|
| 274 |
+
def get_judge_verdict(
    self,
    prompt: str,
    response_a: str,
    response_b: str,
    randomize: bool = True
) -> Verdict:
    """
    Get judge verdict for a single comparison.

    Args:
        prompt: Original user prompt
        response_a: Response from model A (base)
        response_b: Response from model B (fine-tuned)
        randomize: Randomize A/B order to reduce position bias

    Returns:
        Verdict with scores and reasoning, always expressed in the
        caller's A/B order regardless of the order shown to the judge
    """
    # Randomize order to reduce position bias
    swapped = randomize and random.random() > 0.5
    if swapped:
        shown_a, shown_b = response_b, response_a
    else:
        shown_a, shown_b = response_a, response_b

    judge_prompt = self.JUDGE_PROMPT.format(
        prompt=prompt,
        response_a=shown_a,
        response_b=shown_b
    )
    judge_response = self._call_judge(judge_prompt)
    verdict = self._parse_verdict(judge_response, prompt, shown_a, shown_b)

    if not swapped:
        return verdict

    # Map everything the judge labelled A/B back to the caller's order.
    verdict.response_a = response_a
    verdict.response_b = response_b
    verdict.score_a, verdict.score_b = verdict.score_b, verdict.score_a
    if verdict.winner == "A":
        verdict.winner = "B"
    elif verdict.winner == "B":
        verdict.winner = "A"

    # Per-criterion scores are keyed by "A"/"B" as well and must be flipped,
    # otherwise they stay attributed to the wrong model.
    relabel = {"A": "B", "B": "A"}
    verdict.criteria_scores = {
        criterion: (
            {relabel.get(label, label): value for label, value in scores.items()}
            if isinstance(scores, dict) else scores
        )
        for criterion, scores in (verdict.criteria_scores or {}).items()
    }
    # NOTE: the free-text reasoning is left as written by the judge, so it
    # may still refer to the responses by the labels the judge saw.

    return verdict
|
| 331 |
+
|
| 332 |
+
def run_arena(
    self,
    base_model: Any,
    finetuned_model: Any,
    tokenizer: Any,
    test_prompts: List[str],
    finetuned_tokenizer: Optional[Any] = None
) -> ArenaResult:
    """
    Battle the base model against the fine-tuned model on every prompt.

    For each prompt both models generate a response, the judge compares
    them, and the verdicts are aggregated into win counts, average scores
    and the fine-tuned win rate. The summary is printed before returning.

    Args:
        base_model: Base model for comparison
        finetuned_model: Fine-tuned model
        tokenizer: Tokenizer for base model
        test_prompts: List of evaluation prompts
        finetuned_tokenizer: Optional separate tokenizer for fine-tuned model

    Returns:
        ArenaResult with all verdicts and statistics
    """
    console.print("\n" + "="*60)
    console.print("[bold magenta]⚖️  THE JUDGE - MODEL ARENA[/]")
    console.print("="*60)

    battle_count = len(test_prompts)
    console.print(f"\n[bold]Judge Model:[/] {self.judge_model.value}")
    console.print(f"[bold]Test Samples:[/] {battle_count}")

    # The fine-tuned model shares the base tokenizer unless one is supplied.
    ft_tokenizer = finetuned_tokenizer or tokenizer
    verdicts = []
    started_at = datetime.now()

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        console=console
    ) as progress:
        task = progress.add_task("Running arena battles...", total=battle_count)

        for round_no, prompt in enumerate(test_prompts, start=1):
            progress.update(task, description=f"Battle {round_no}/{battle_count}...")

            base_answer = self.generate_response(base_model, tokenizer, prompt)
            ft_answer = self.generate_response(finetuned_model, ft_tokenizer, prompt)
            verdicts.append(self.get_judge_verdict(prompt, base_answer, ft_answer))

            progress.advance(task)

    evaluation_time = (datetime.now() - started_at).total_seconds()

    # Aggregate: "A" is the base model, "B" is the fine-tuned model.
    winners = [v.winner for v in verdicts]
    base_wins = winners.count("A")
    ft_wins = winners.count("B")
    ties = winners.count("TIE")

    judged = len(verdicts)
    if verdicts:
        base_avg = sum(v.score_a for v in verdicts) / judged
        ft_avg = sum(v.score_b for v in verdicts) / judged
        win_rate = ft_wins / judged
    else:
        base_avg = 0
        ft_avg = 0
        win_rate = 0

    result = ArenaResult(
        verdicts=verdicts,
        base_model_wins=base_wins,
        finetuned_wins=ft_wins,
        ties=ties,
        base_model_avg_score=base_avg,
        finetuned_avg_score=ft_avg,
        win_rate=win_rate,
        total_comparisons=judged,
        evaluation_time=evaluation_time,
        judge_model=self.judge_model.value
    )

    self._display_results(result)

    return result
|
| 415 |
+
|
| 416 |
+
def _display_results(self, result: ArenaResult):
    """
    Print the arena summary: battle statistics, average scores and verdict.

    Args:
        result: Aggregated ArenaResult to display
    """
    console.print("\n" + "-"*40)
    console.print("[bold]📊 ARENA RESULTS[/]")
    console.print("-"*40)

    # Win statistics
    table = Table(title="Battle Statistics", show_header=True)
    table.add_column("Metric", style="cyan")
    table.add_column("Value", style="green")

    table.add_row("Base Model Wins", str(result.base_model_wins))
    table.add_row("Fine-tuned Wins", f"[bold green]{result.finetuned_wins}[/]")
    table.add_row("Ties", str(result.ties))
    table.add_row("Total Comparisons", str(result.total_comparisons))
    table.add_row("Fine-tuned Win Rate", f"[bold]{result.win_rate:.1%}[/]")

    console.print(table)

    # Score comparison
    table2 = Table(title="Average Scores (1-10)", show_header=True)
    table2.add_column("Model", style="cyan")
    table2.add_column("Score", style="green")

    score_delta = result.finetuned_avg_score - result.base_model_avg_score
    table2.add_row("Base Model", f"{result.base_model_avg_score:.2f}")
    table2.add_row("Fine-tuned Model", f"[bold]{result.finetuned_avg_score:.2f}[/]")
    # Signed format: a regression prints "-0.50" rather than "+-0.50".
    table2.add_row("Improvement", f"{score_delta:+.2f}")

    console.print(table2)

    # Verdict: relative score change, guarded against a zero baseline.
    if result.base_model_avg_score > 0:
        improvement_pct = ((result.finetuned_avg_score / result.base_model_avg_score) - 1) * 100
    else:
        improvement_pct = 0

    if result.win_rate > 0.6:
        verdict_text = f"[bold green]✅ SIGNIFICANT IMPROVEMENT[/]\nFine-tuned model wins {result.win_rate:.0%} of battles with {improvement_pct:.1f}% score improvement!"
    elif result.win_rate > 0.4:
        verdict_text = f"[bold yellow]⚖️  MARGINAL IMPROVEMENT[/]\nFine-tuned model shows moderate improvement ({result.win_rate:.0%} win rate)"
    else:
        verdict_text = "[bold red]⚠️  NO IMPROVEMENT[/]\nFine-tuning did not improve model performance. Consider adjusting training data or hyperparameters."

    console.print(Panel(verdict_text, title="Final Verdict", border_style="blue"))
|
| 457 |
+
|
| 458 |
+
def generate_report(
    self,
    result: ArenaResult,
    output_path: str,
    include_examples: int = 5
) -> str:
    """
    Write the arena outcome to a JSON report file.

    The report holds a timestamp, the judge model, the summary statistics
    and the first `include_examples` verdicts, with each response cut to
    500 characters.

    Args:
        result: ArenaResult from run_arena
        output_path: Path to save the report
        include_examples: Number of example comparisons to include

    Returns:
        Path to the generated report
    """

    def _clip(text):
        # Responses longer than 500 characters are cut and marked with "...".
        if len(text) > 500:
            return text[:500] + "..."
        return text

    report_path = Path(output_path)
    report_path.parent.mkdir(parents=True, exist_ok=True)

    created_at = datetime.now().isoformat()

    summary = {
        "total_comparisons": result.total_comparisons,
        "base_model_wins": result.base_model_wins,
        "finetuned_wins": result.finetuned_wins,
        "ties": result.ties,
        "finetuned_win_rate": result.win_rate,
        "base_model_avg_score": result.base_model_avg_score,
        "finetuned_avg_score": result.finetuned_avg_score,
        "score_improvement": result.finetuned_avg_score - result.base_model_avg_score,
        "evaluation_time_seconds": result.evaluation_time
    }

    examples = []
    for v in result.verdicts[:include_examples]:
        examples.append({
            "prompt": v.prompt,
            "response_base": _clip(v.response_a),
            "response_finetuned": _clip(v.response_b),
            "winner": v.winner,
            "score_base": v.score_a,
            "score_finetuned": v.score_b,
            "reasoning": v.reasoning
        })

    report = {
        "timestamp": created_at,
        "judge_model": result.judge_model,
        "summary": summary,
        "example_verdicts": examples
    }

    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2, ensure_ascii=False)

    console.print(f"\n[green]✓ Report saved to: {report_path}[/]")

    return str(report_path)
|
| 512 |
+
|
| 513 |
+
def run_with_test_data(
    self,
    base_model: Any,
    finetuned_model: Any,
    tokenizer: Any,
    test_data_path: str,
    num_samples: int = 50,
    finetuned_tokenizer: Optional[Any] = None
) -> ArenaResult:
    """
    Run arena with test data from a JSONL file.

    Each non-blank line must be a JSON object; its prompt is read from the
    'instruction' key, falling back to 'prompt'. Blank lines and records
    without a prompt are skipped. If more than `num_samples` prompts are
    available, a random sample of that size is evaluated.

    Args:
        base_model: Base model
        finetuned_model: Fine-tuned model
        tokenizer: Tokenizer
        test_data_path: Path to test JSONL file
        num_samples: Number of samples to evaluate
        finetuned_tokenizer: Optional separate tokenizer

    Returns:
        ArenaResult

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON
    """
    console.print(f"\n[blue]Loading test data from: {test_data_path}[/]")

    prompts = []
    skipped = 0
    with open(test_data_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not records.
                continue
            data = json.loads(line)
            text = data.get('instruction') or data.get('prompt') or ''
            if not text:
                # An empty prompt would waste generations and a judge call.
                skipped += 1
                continue
            prompts.append(text)

    if skipped:
        console.print(f"[yellow]Skipped {skipped} record(s) without an instruction or prompt[/]")

    # Sample if needed
    if len(prompts) > num_samples:
        prompts = random.sample(prompts, num_samples)

    console.print(f"[green]✓ Loaded {len(prompts)} test prompts[/]")

    return self.run_arena(
        base_model,
        finetuned_model,
        tokenizer,
        prompts,
        finetuned_tokenizer
    )
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
if __name__ == "__main__":
    import sys

    # Three positional arguments are expected after the script name.
    if len(sys.argv) >= 4:
        print("TheJudge requires models to be loaded. See main.py for integrated usage.")
    else:
        print("Usage: python the_judge.py <base_model_path> <finetuned_model_path> <test_data.jsonl>")
        sys.exit(1)
|
agents/training_pilot.py
ADDED
|
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TrainingPilot - Automated Fine-Tuning Agent
|
| 3 |
+
=============================================
|
| 4 |
+
Uses Unsloth for ultra-fast LoRA fine-tuning with auto-configured
|
| 5 |
+
hyperparameters based on dataset size.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import yaml
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Optional, Dict, Any, Tuple
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
|
| 15 |
+
from rich.console import Console
|
| 16 |
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeRemainingColumn
|
| 17 |
+
from rich.table import Table
|
| 18 |
+
from rich.panel import Panel
|
| 19 |
+
|
| 20 |
+
console = Console()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class HyperParams:
    """Training hyperparameters configuration."""
    # LoRA adapter settings (passed to FastLanguageModel.get_peft_model).
    lora_rank: int = 16
    lora_alpha: int = 32
    # Optimisation settings (passed to transformers.TrainingArguments).
    learning_rate: float = 1e-4
    num_epochs: int = 3
    batch_size: int = 8
    gradient_accumulation_steps: int = 2
    warmup_ratio: float = 0.03
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    optimizer: str = "adamw_8bit"
    lr_scheduler: str = "cosine"
    gradient_checkpointing: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """
        Return every hyperparameter as a plain dict keyed by field name.

        Returns:
            New dict in field-declaration order; mutating it does not
            affect this instance
        """
        # Derived from the dataclass fields so the mapping cannot drift
        # out of sync when a field is added or renamed.
        from dataclasses import asdict

        return asdict(self)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass
class TrainingResult:
    """Outcome of one fine-tuning run (the return type of TrainingPilot.train)."""
    # Location of the trained model.
    model_path: str
    # Duration of the run — units are not visible here; TODO confirm seconds.
    training_time: float
    # Loss reported at the end of training.
    final_loss: float
    # Number of training steps performed.
    num_steps: int
    # Hyperparameters the run was launched with.
    hyperparams: HyperParams
    # Number of samples in the training dataset.
    dataset_size: int
    # Additional metrics; schema is not visible here.
    metrics: Dict[str, Any]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class TrainingPilot:
|
| 69 |
+
"""
|
| 70 |
+
Automated fine-tuning agent using Unsloth for ultra-fast LoRA training.
|
| 71 |
+
|
| 72 |
+
Features:
|
| 73 |
+
- Auto-configures hyperparameters based on dataset size
|
| 74 |
+
- Uses 4-bit quantization for memory efficiency
|
| 75 |
+
- Supports gradient checkpointing
|
| 76 |
+
- Automatic checkpoint saving
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
# Dataset size thresholds
|
| 80 |
+
SMALL_THRESHOLD = 1000
|
| 81 |
+
MEDIUM_THRESHOLD = 10000
|
| 82 |
+
|
| 83 |
+
# Default target modules for LoRA
|
| 84 |
+
DEFAULT_TARGET_MODULES = [
|
| 85 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 86 |
+
"gate_proj", "up_proj", "down_proj"
|
| 87 |
+
]
|
| 88 |
+
|
| 89 |
+
def __init__(
    self,
    config_path: Optional[str] = None,
    base_model: str = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length: int = 2048,
    output_dir: str = "./output/models"
):
    """
    Prepare a pilot: read the optional config and create the output folder.

    No model is loaded here; `model`, `tokenizer` and `trainer` start as
    None (setup_model() assigns the first two).

    Args:
        config_path: Path to config YAML file
        base_model: HuggingFace model identifier
        max_seq_length: Maximum sequence length
        output_dir: Directory for saving models
    """
    self.config = self._load_config(config_path)
    self.max_seq_length = max_seq_length
    self.base_model = base_model

    # Make sure the output directory exists before anything is written to it.
    models_dir = Path(output_dir)
    self.output_dir = models_dir
    models_dir.mkdir(parents=True, exist_ok=True)

    self.trainer = None
    self.tokenizer = None
    self.model = None
|
| 114 |
+
|
| 115 |
+
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
    """
    Load configuration from YAML file.

    Args:
        config_path: Path to the YAML file, or None

    Returns:
        Parsed configuration; an empty dict when no path is given, the
        file does not exist, or the file is empty
    """
    if not config_path or not Path(config_path).exists():
        return {}

    # Explicit encoding so parsing does not depend on the platform default.
    with open(config_path, 'r', encoding='utf-8') as f:
        loaded = yaml.safe_load(f)

    # safe_load returns None for an empty document; keep the dict contract.
    return {} if loaded is None else loaded
|
| 121 |
+
|
| 122 |
+
def auto_configure(self, dataset_size: int) -> HyperParams:
    """
    Pick a hyperparameter preset from the dataset size and print it.

    Sizes below SMALL_THRESHOLD get the SMALL preset, sizes below
    MEDIUM_THRESHOLD the MEDIUM preset, everything else the LARGE preset.

    Args:
        dataset_size: Number of training samples

    Returns:
        Optimized HyperParams configuration
    """
    console.print(f"\n[bold blue]⚙️  Auto-configuring for {dataset_size:,} samples...[/]")

    if dataset_size < self.SMALL_THRESHOLD:
        # Small dataset: Higher learning rate, more epochs, smaller batch
        tier = "SMALL"
        preset = dict(
            lora_rank=8,
            lora_alpha=16,
            learning_rate=2e-4,
            num_epochs=5,
            batch_size=4,
            gradient_accumulation_steps=4,
        )
    elif dataset_size < self.MEDIUM_THRESHOLD:
        # Medium dataset: Balanced parameters
        tier = "MEDIUM"
        preset = dict(
            lora_rank=16,
            lora_alpha=32,
            learning_rate=1e-4,
            num_epochs=3,
            batch_size=8,
            gradient_accumulation_steps=2,
        )
    else:
        # Large dataset: Lower learning rate, fewer epochs, larger batch
        tier = "LARGE"
        preset = dict(
            lora_rank=32,
            lora_alpha=64,
            learning_rate=5e-5,
            num_epochs=2,
            batch_size=16,
            gradient_accumulation_steps=1,
        )

    params = HyperParams(**preset)

    # Show every parameter, including the ones left at their defaults.
    table = Table(title=f"Auto-Configured Parameters [{tier}]", show_header=True)
    table.add_column("Parameter", style="cyan")
    table.add_column("Value", style="green")
    for name, setting in params.to_dict().items():
        table.add_row(name, str(setting))
    console.print(table)

    return params
|
| 179 |
+
|
| 180 |
+
def setup_model(
    self,
    hyperparams: HyperParams,
    model_name: Optional[str] = None
) -> Tuple[Any, Any]:
    """
    Setup the model with LoRA configuration using Unsloth.

    Loads the model in 4-bit, wraps it with LoRA adapters on
    DEFAULT_TARGET_MODULES, and stores the model and tokenizer on the
    instance.

    Args:
        hyperparams: Training hyperparameters
        model_name: Override model name

    Returns:
        Tuple of (model, tokenizer)

    Raises:
        ImportError: If Unsloth is not installed
    """
    console.print("\n[bold blue]🚀 Setting up model with Unsloth...[/]")

    try:
        from unsloth import FastLanguageModel
    except ImportError as exc:
        console.print("[red]❌ Unsloth not installed. Please install with: pip install unsloth[/]")
        # Chain the original error so the underlying import failure stays visible.
        raise ImportError("Unsloth is required for training. Install with: pip install unsloth") from exc

    model_name = model_name or self.base_model

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console
    ) as progress:
        task = progress.add_task("Loading model...", total=None)

        # Load model with Unsloth
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=self.max_seq_length,
            dtype=None,  # Auto-detect
            load_in_4bit=True,
        )

        progress.update(task, description="Applying LoRA...")

        # Honour the hyperparameter flag instead of always enabling
        # checkpointing; the default (True) keeps the "unsloth" mode.
        checkpointing = "unsloth" if hyperparams.gradient_checkpointing else False

        # Apply LoRA with PEFT
        model = FastLanguageModel.get_peft_model(
            model,
            r=hyperparams.lora_rank,
            lora_alpha=hyperparams.lora_alpha,
            lora_dropout=0,
            target_modules=self.DEFAULT_TARGET_MODULES,
            bias="none",
            use_gradient_checkpointing=checkpointing,
            random_state=42,
            use_rslora=True,
        )

        progress.update(task, completed=True)

    self.model = model
    self.tokenizer = tokenizer

    console.print("[green]✓ Model setup complete[/]")
    self._print_model_info()

    return model, tokenizer
|
| 244 |
+
|
| 245 |
+
def _print_model_info(self):
    """Show the model name and its trainable/total parameter counts in a panel."""
    if self.model is None:
        return

    # Tally both counts in a single walk over the parameters.
    trainable_count = 0
    overall_count = 0
    for param in self.model.parameters():
        size = param.numel()
        overall_count += size
        if param.requires_grad:
            trainable_count += size

    summary = (
        f"[bold]Model:[/] {self.base_model}\n"
        f"[bold]Trainable Parameters:[/] {trainable_count:,} ({100 * trainable_count / overall_count:.2f}%)\n"
        f"[bold]Total Parameters:[/] {overall_count:,}"
    )
    console.print(Panel(
        summary,
        title="Model Information",
        border_style="blue"
    ))
|
| 261 |
+
|
| 262 |
+
def load_dataset(self, data_path: str) -> Any:
    """
    Read the training samples from a JSON/JSONL file.

    Args:
        data_path: Path to JSONL training file

    Returns:
        The 'train' split produced by the HuggingFace `datasets` JSON loader
    """
    # Imported under an alias so it cannot be confused with this method.
    from datasets import load_dataset as hf_load_dataset

    console.print(f"\n[bold blue]📂 Loading training data:[/] {data_path}")

    train_split = hf_load_dataset('json', data_files=data_path, split='train')
    sample_count = len(train_split)
    console.print(f"[green]✓ Loaded {sample_count:,} training samples[/]")

    return train_split
|
| 280 |
+
|
| 281 |
+
def _format_prompts(self, dataset: Any) -> Any:
    """
    Format dataset into training prompts.

    Each example is rendered with the Alpaca template (the variant with an
    "### Input:" section when the example has a non-blank 'input') and the
    tokenizer's EOS token is appended, so the model is shown where a
    response ends.

    Args:
        dataset: HuggingFace Dataset with 'instruction', 'output' and
            optional 'input' fields

    Returns:
        Result of `dataset.map(...)`, where each example gains a 'text' field
    """
    # Alpaca-style prompt templates, spelled out with explicit newlines.
    alpaca_template = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n\n"
        "### Instruction:\n"
        "{instruction}\n\n"
        "### Input:\n"
        "{input}\n\n"
        "### Response:\n"
        "{output}"
    )
    alpaca_template_no_input = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n"
        "### Instruction:\n"
        "{instruction}\n\n"
        "### Response:\n"
        "{output}"
    )

    # Without a trailing EOS the training text never marks the end of an
    # answer. Falls back to "" when no tokenizer (or no EOS token) is set.
    eos_token = getattr(self.tokenizer, "eos_token", None) or ""

    def format_prompt(example):
        extra_input = example.get('input')
        if extra_input and str(extra_input).strip():
            text = alpaca_template.format(
                instruction=example['instruction'],
                input=extra_input,
                output=example['output']
            )
        else:
            text = alpaca_template_no_input.format(
                instruction=example['instruction'],
                output=example['output']
            )
        return {"text": text + eos_token}

    return dataset.map(format_prompt)
|
| 326 |
+
|
| 327 |
+
def train(
    self,
    dataset: Any,
    hyperparams: HyperParams,
    output_name: Optional[str] = None
) -> TrainingResult:
    """
    Run the fine-tuning training loop.

    Formats the dataset into prompts, trains with TRL's ``SFTTrainer``,
    saves the model and tokenizer under ``self.output_dir / output_name``
    and prints a summary of the run.

    Args:
        dataset: Training dataset (rendered through ``_format_prompts``)
        hyperparams: Training hyperparameters
        output_name: Custom name for output model; defaults to
            ``finetuned_model_<timestamp>``

    Returns:
        TrainingResult with metrics and model path

    Raises:
        RuntimeError: If ``self.model`` or ``self.tokenizer`` is not set.
    """
    import torch
    from trl import SFTTrainer
    from transformers import TrainingArguments

    console.print("\n" + "="*60)
    console.print("[bold magenta]🎯 TRAINING PILOT - STARTING TRAINING[/]")
    console.print("="*60)

    if self.model is None or self.tokenizer is None:
        raise RuntimeError("Model not setup. Call setup_model() first.")

    # Format dataset
    console.print("\n[blue]Formatting training prompts...[/]")
    formatted_dataset = self._format_prompts(dataset)

    # Generate output name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_name = output_name or f"finetuned_model_{timestamp}"
    model_output_path = self.output_dir / output_name

    # Choose the mixed-precision mode that matches the loaded weights.
    # Hard-coding fp16 breaks when the model was loaded in bfloat16, so
    # follow the model's dtype when it reports a half-precision type and
    # otherwise fall back to what the GPU supports.
    model_dtype = getattr(self.model, "dtype", None)
    if model_dtype in (torch.float16, torch.bfloat16):
        use_bf16 = model_dtype == torch.bfloat16
    else:
        use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    # Setup training arguments
    training_args = TrainingArguments(
        output_dir=str(model_output_path),
        num_train_epochs=hyperparams.num_epochs,
        per_device_train_batch_size=hyperparams.batch_size,
        gradient_accumulation_steps=hyperparams.gradient_accumulation_steps,
        learning_rate=hyperparams.learning_rate,
        warmup_ratio=hyperparams.warmup_ratio,
        weight_decay=hyperparams.weight_decay,
        max_grad_norm=hyperparams.max_grad_norm,
        lr_scheduler_type=hyperparams.lr_scheduler,
        optim=hyperparams.optimizer,
        fp16=not use_bf16,
        bf16=use_bf16,
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=2,
        report_to="none",
        seed=42,
    )

    # Initialize trainer
    trainer = SFTTrainer(
        model=self.model,
        tokenizer=self.tokenizer,
        train_dataset=formatted_dataset,
        dataset_text_field="text",
        max_seq_length=self.max_seq_length,
        args=training_args,
    )

    # Keep a handle on the trainer for later inspection
    self.trainer = trainer

    # Train
    console.print("\n[bold green]🏋️ Training in progress...[/]")
    start_time = datetime.now()

    train_result = trainer.train()

    training_time = (datetime.now() - start_time).total_seconds()

    # Save model
    console.print("\n[blue]Saving model...[/]")
    trainer.save_model(str(model_output_path))
    self.tokenizer.save_pretrained(str(model_output_path))

    # Get final metrics
    final_loss = train_result.training_loss
    num_steps = train_result.global_step

    result = TrainingResult(
        model_path=str(model_output_path),
        training_time=training_time,
        final_loss=final_loss,
        num_steps=num_steps,
        hyperparams=hyperparams,
        # Report the size of the raw dataset that was passed in
        dataset_size=len(dataset),
        metrics=train_result.metrics
    )

    # Display results
    self._display_results(result)

    return result
|
| 426 |
+
|
| 427 |
+
def _display_results(self, result: TrainingResult):
    """Print a summary table for a finished training run to the console."""
    # Break the elapsed seconds down into hours / minutes / seconds
    hours, leftover = divmod(result.training_time, 3600)
    minutes, seconds = divmod(leftover, 60)
    elapsed = f"{int(hours)}h {int(minutes)}m {int(seconds)}s"

    summary = Table(title="Training Complete", show_header=True)
    for heading, colour in (("Metric", "cyan"), ("Value", "green")):
        summary.add_column(heading, style=colour)

    summary_rows = (
        ("Model Path", result.model_path),
        ("Training Time", elapsed),
        ("Final Loss", f"{result.final_loss:.4f}"),
        ("Total Steps", str(result.num_steps)),
        ("Dataset Size", f"{result.dataset_size:,}"),
    )
    for metric_name, metric_value in summary_rows:
        summary.add_row(metric_name, metric_value)

    console.print(summary)
    console.print("\n[bold green]✅ Training complete![/]")
|
| 445 |
+
|
| 446 |
+
def export_for_deployment(self, model_path: str, export_path: Optional[str] = None) -> str:
    """
    Export the fine-tuned model for deployment.

    Reloads the trained model through Unsloth and writes a merged 16-bit
    copy to ``export_path``.

    Args:
        model_path: Path to the trained model
        export_path: Custom export path; defaults to a ``deployment``
            directory inside ``model_path``

    Returns:
        Path to exported model

    Raises:
        ImportError: If Unsloth is not installed (the original import
            failure is attached as the cause).
    """
    try:
        from unsloth import FastLanguageModel
    except ImportError as exc:
        # Chain explicitly so the underlying import failure stays visible
        raise ImportError("Unsloth is required for export") from exc

    console.print("\n[bold blue]📦 Exporting model for deployment...[/]")

    export_path = export_path or str(Path(model_path) / "deployment")

    # Load and merge LoRA weights
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        max_seq_length=self.max_seq_length,
        dtype=None,
        load_in_4bit=True,
    )

    # Save merged model
    model.save_pretrained_merged(export_path, tokenizer, save_method="merged_16bit")

    console.print(f"[green]✓ Exported to: {export_path}[/]")

    return export_path
|
| 480 |
+
|
| 481 |
+
def run(
    self,
    data_path: str,
    model_name: Optional[str] = None,
    output_name: Optional[str] = None
) -> TrainingResult:
    """
    Execute the whole training pipeline in one call.

    Steps, in order: load the dataset, derive hyperparameters from its
    size, set up the model, then train.

    Args:
        data_path: Path to JSONL training data
        model_name: Override base model
        output_name: Custom name for output model

    Returns:
        TrainingResult produced by ``train``
    """
    banner = "=" * 60
    console.print("\n" + banner)
    console.print("[bold magenta]🧑✈️ TRAINING PILOT AGENT[/]")
    console.print(banner)

    # Data first: the hyperparameters are derived from the dataset size
    train_data = self.load_dataset(data_path)
    auto_params = self.auto_configure(len(train_data))

    # The model must be set up before train() is called
    self.setup_model(auto_params, model_name)

    return self.train(train_data, auto_params, output_name)
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
if __name__ == "__main__":
    import sys

    # Command-line entry point: <training_data.jsonl> is required,
    # [output_name] is optional.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python training_pilot.py <training_data.jsonl> [output_name]")
        sys.exit(1)

    data_path = cli_args[0]
    output_name = cli_args[1] if len(cli_args) > 1 else None

    pilot = TrainingPilot()
    result = pilot.run(data_path, output_name=output_name)
|
app.py
ADDED
|
@@ -0,0 +1,1500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Auto-FineTune-Ops: Streamlit Dashboard
|
| 3 |
+
======================================
|
| 4 |
+
Premium interactive dashboard for ML fine-tuning pipeline.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import streamlit as st
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import plotly.express as px
|
| 10 |
+
import plotly.graph_objects as go
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import sys
|
| 13 |
+
import os
|
| 14 |
+
import json
|
| 15 |
+
import time
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
|
| 18 |
+
# Put the directory containing this file at the front of the import path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

# Page configuration
page_settings = {
    "page_title": "Auto-FineTune-Ops",
    "page_icon": "🤖",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**page_settings)

# Premium CSS styling, injected once as a raw <style> element
premium_css = """
<style>
    /* Main container */
    .main .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
    }

    /* Cards */
    .stMetric {
        background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
        padding: 1rem;
        border-radius: 12px;
        border: 1px solid rgba(99, 102, 241, 0.2);
        box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3);
    }

    /* Gradient headers */
    .gradient-header {
        background: linear-gradient(90deg, #6366f1, #8b5cf6, #a855f7);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        font-size: 2.5rem;
        font-weight: 700;
        margin-bottom: 1rem;
    }

    /* Info cards */
    .info-card {
        background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
        padding: 1.5rem;
        border-radius: 16px;
        border: 1px solid rgba(99, 102, 241, 0.3);
        margin: 1rem 0;
    }

    /* Success badge */
    .success-badge {
        background: linear-gradient(90deg, #10b981, #059669);
        color: white;
        padding: 0.5rem 1rem;
        border-radius: 20px;
        font-weight: 600;
        display: inline-block;
    }

    /* Warning badge */
    .warning-badge {
        background: linear-gradient(90deg, #f59e0b, #d97706);
        color: white;
        padding: 0.5rem 1rem;
        border-radius: 20px;
        font-weight: 600;
        display: inline-block;
    }

    /* Sidebar styling */
    section[data-testid="stSidebar"] {
        background: linear-gradient(180deg, #0f0f23 0%, #1a1a2e 100%);
    }

    /* Button styling */
    .stButton > button {
        background: linear-gradient(90deg, #6366f1, #8b5cf6);
        color: white;
        border: none;
        border-radius: 8px;
        padding: 0.5rem 2rem;
        font-weight: 600;
        transition: all 0.3s ease;
    }

    .stButton > button:hover {
        transform: translateY(-2px);
        box-shadow: 0 4px 20px rgba(99, 102, 241, 0.4);
    }

    /* Progress bar */
    .stProgress > div > div {
        background: linear-gradient(90deg, #6366f1, #8b5cf6, #a855f7);
    }

    /* Tab styling */
    .stTabs [data-baseweb="tab-list"] {
        gap: 8px;
    }

    .stTabs [data-baseweb="tab"] {
        background: rgba(99, 102, 241, 0.1);
        border-radius: 8px;
        padding: 0.5rem 1rem;
    }

    .stTabs [aria-selected="true"] {
        background: linear-gradient(90deg, #6366f1, #8b5cf6);
    }
</style>
"""
st.markdown(premium_css, unsafe_allow_html=True)
|
| 128 |
+
|
| 129 |
+
# Initialize session state: seed each key only when it is absent, so values
# already present in the session are never overwritten.
for state_key, default_value in (
    ('current_page', 'home'),
    ('uploaded_data', None),
    ('processed_data_path', None),
    ('model_path', None),
    ('training_goal', None),
    ('pipeline_status', {
        'data': 'pending',
        'training': 'pending',
        'evaluation': 'pending',
        'deployment': 'pending'
    }),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value
|
| 147 |
+
|
| 148 |
+
# Sidebar navigation
with st.sidebar:
    st.markdown('<p class="gradient-header" style="font-size: 1.5rem;">🤖 Auto-FineTune-Ops</p>', unsafe_allow_html=True)
    st.markdown("---")

    # Navigation: page key -> (icon, button label)
    pages = {
        'home': ('🏠', 'Dashboard'),
        'data': ('📊', 'Data Upload'),
        'process': ('🧹', 'Processing'),
        'training': ('🚀', 'Training'),
        'evaluation': ('⚖️', 'Evaluation'),
        'deploy': ('🌐', 'Deploy')
    }

    for page_key, (page_icon, page_label) in pages.items():
        clicked = st.button(
            f"{page_icon} {page_label}",
            key=f"nav_{page_key}",
            use_container_width=True,
        )
        if clicked:
            # Record the selected page key in session state
            st.session_state.current_page = page_key

    st.markdown("---")

    # Pipeline status: one line per stage, prefixed with a status emoji
    st.markdown("### 📋 Pipeline Status")
    status_icons = {'pending': '⏳', 'running': '🔄', 'complete': '✅', 'error': '❌'}
    for stage_name, stage_status in st.session_state.pipeline_status.items():
        badge = status_icons.get(stage_status, '⏳')
        st.markdown(f"{badge} **{stage_name.title()}**: {stage_status}")

    st.markdown("---")
    st.markdown("*Built with ❤️ using Streamlit*")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ============================================================================
|
| 180 |
+
# PAGE: HOME DASHBOARD
|
| 181 |
+
# ============================================================================
|
| 182 |
+
def _render_recent_files(directory, pattern, icon, empty_message, missing_message, limit=5):
    """Show up to ``limit`` entries of ``directory`` matching ``pattern``, newest first."""
    if not directory.exists():
        st.info(missing_message)
        return

    def modified_at(path):
        # An entry may disappear (or be a dangling link) between glob and stat
        try:
            return path.stat().st_mtime
        except OSError:
            return 0.0

    # glob() order is arbitrary; sort so the slice really is the most recent
    entries = sorted(directory.glob(pattern), key=modified_at, reverse=True)
    if not entries:
        st.info(empty_message)
        return

    for entry in entries[:limit]:
        st.markdown(f"- {icon} `{entry.name}`")


def render_home():
    """
    Render the home dashboard page.

    Shows one status card per pipeline stage, a four-step quick start
    guide, and the most recent models, reports and logs found under
    ``./output`` (if that directory exists).
    """
    st.markdown('<p class="gradient-header">🏠 Pipeline Dashboard</p>', unsafe_allow_html=True)
    st.markdown("**One-click autonomous ML fine-tuning pipeline**")

    # Status cards
    col1, col2, col3, col4 = st.columns(4)

    dataset_loaded = st.session_state.uploaded_data is not None
    with col1:
        st.metric(
            label="📊 Dataset",
            value="Ready" if dataset_loaded else "Not Loaded",
            delta="Uploaded" if dataset_loaded else None
        )

    # The remaining cards mirror entries of the pipeline_status mapping
    stage_cards = (
        (col2, "🧹 Processing", 'data'),
        (col3, "🚀 Training", 'training'),
        (col4, "⚖️ Evaluation", 'evaluation'),
    )
    for column, label, stage in stage_cards:
        status = st.session_state.pipeline_status[stage]
        with column:
            st.metric(
                label=label,
                value=status.title(),
                delta="Complete" if status == 'complete' else None
            )

    st.markdown("---")

    # Quick start guide
    st.markdown("### 🚀 Quick Start Guide")

    left, right = st.columns(2)
    guide = (
        (left, (
            ("📊 Step 1: Upload Data", "Upload your CSV/JSON dataset with instruction-response pairs."),
            ("🧹 Step 2: Process Data", "The DataArchitectAgent will clean and format your data."),
        )),
        (right, (
            ("🚀 Step 3: Train Model", "Fine-tune with auto-configured hyperparameters."),
            ("⚖️ Step 4: Evaluate", "Run Model Arena with LLM-as-Judge evaluation."),
        )),
    )
    for column, cards in guide:
        with column:
            for title, description in cards:
                st.markdown(
                    '<div class="info-card">\n'
                    f'    <h4>{title}</h4>\n'
                    f'    <p>{description}</p>\n'
                    '</div>',
                    unsafe_allow_html=True,
                )

    # Recent output files
    st.markdown("---")
    st.markdown("### 📁 Output Files")

    output_dir = Path("./output")
    if not output_dir.exists():
        st.info("Output directory will be created when you run the pipeline.")
        return

    models_tab, reports_tab, logs_tab = st.tabs(["📂 Models", "📊 Reports", "📝 Logs"])

    with models_tab:
        _render_recent_files(
            output_dir / "models", "*", "🤖",
            "No trained models yet.", "Models directory not found.",
        )

    with reports_tab:
        _render_recent_files(
            output_dir / "reports", "*.json", "📊",
            "No evaluation reports yet.", "Reports directory not found.",
        )

    with logs_tab:
        _render_recent_files(
            output_dir / "logs", "*.yaml", "📝",
            "No log files yet.", "Logs directory not found.",
        )
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
# ============================================================================
|
| 302 |
+
# PAGE: DATA UPLOAD
|
| 303 |
+
# ============================================================================
|
| 304 |
+
def _first_matching_column(columns, patterns):
    """Return the first column whose lowercased name contains any pattern, else None."""
    for column in columns:
        # Column labels are not always strings, so normalise before matching
        lowered = str(column).lower()
        if any(pattern in lowered for pattern in patterns):
            return column
    return None


def _summarize_columns(df):
    """Build one summary record per column; safe for a DataFrame with zero rows."""
    has_rows = len(df) > 0
    records = []
    for column in df.columns:
        # iloc[0] raises on an empty frame, so only sample when rows exist
        sample = str(df[column].iloc[0]) if has_rows else ''
        if len(sample) > 80:
            sample = sample[:80] + '...'
        records.append({
            'Column': column,
            'Type': str(df[column].dtype),
            'Non-Null': df[column].notna().sum(),
            'Unique': df[column].nunique(),
            'Sample': sample
        })
    return records


def render_data_upload():
    """
    Render the data upload page.

    Lets the user upload a CSV/JSON/JSONL file (or merge an additional file
    with identical columns into the loaded dataset), then shows dataset
    statistics, auto-detected instruction/output columns, a full preview,
    download buttons and a per-column summary. The loaded DataFrame is kept
    in ``st.session_state.uploaded_data``.
    """
    import html

    st.markdown('<p class="gradient-header">📊 Data Upload & Preview</p>', unsafe_allow_html=True)

    # ── File Management Bar ──
    if st.session_state.uploaded_data is not None:
        fm1, fm2, fm3 = st.columns([3, 1, 1])
        with fm1:
            st.info(f"📂 Currently loaded: **{st.session_state.get('uploaded_filename', 'dataset')}** ({len(st.session_state.uploaded_data):,} rows)")
        with fm2:
            if st.button("🗑️ Remove Dataset", type="secondary"):
                st.session_state.uploaded_data = None
                st.session_state.uploaded_filename = None
                st.session_state.processed_data_path = None
                st.session_state.pipeline_status['data'] = 'pending'
                st.rerun()
        with fm3:
            if st.button("📎 Add More Data"):
                st.session_state['show_add_file'] = True

    # ── File Uploader ──
    show_uploader = (st.session_state.uploaded_data is None) or st.session_state.get('show_add_file', False)

    if show_uploader:
        upload_label = "Upload your dataset (CSV, JSON, or JSONL)" if st.session_state.uploaded_data is None else "Upload additional file to merge with current dataset"
        uploaded_file = st.file_uploader(
            upload_label,
            type=['csv', 'json', 'jsonl'],
            help="Your dataset should contain instruction-response pairs.",
            # A changing key gives a fresh uploader widget after each load
            key=f"uploader_{st.session_state.get('upload_counter', 0)}"
        )

        if uploaded_file:
            try:
                # Compare the extension case-insensitively (e.g. "DATA.CSV")
                lowered_name = uploaded_file.name.lower()
                if lowered_name.endswith('.csv'):
                    new_df = pd.read_csv(uploaded_file)
                elif lowered_name.endswith('.jsonl'):
                    new_df = pd.read_json(uploaded_file, lines=True)
                else:
                    new_df = pd.read_json(uploaded_file)

                # Merge or replace
                if st.session_state.uploaded_data is not None and st.session_state.get('show_add_file', False):
                    existing_df = st.session_state.uploaded_data
                    if list(new_df.columns) == list(existing_df.columns):
                        st.session_state.uploaded_data = pd.concat([existing_df, new_df], ignore_index=True)
                        st.session_state.uploaded_filename = f"{st.session_state.get('uploaded_filename', 'data')} + {uploaded_file.name}"
                        st.success(f"✅ Merged **{uploaded_file.name}** ({len(new_df):,} rows) → Total: **{len(st.session_state.uploaded_data):,}** rows")
                    else:
                        st.error(f"❌ Column mismatch! Existing: {list(existing_df.columns)} vs New: {list(new_df.columns)}")
                else:
                    st.session_state.uploaded_data = new_df
                    st.session_state.uploaded_filename = uploaded_file.name
                    st.success(f"✅ Successfully loaded **{uploaded_file.name}**")

                st.session_state['show_add_file'] = False
                st.session_state['upload_counter'] = st.session_state.get('upload_counter', 0) + 1

            except Exception as e:
                # UI boundary: report any parsing failure to the user
                st.error(f"Error loading file: {str(e)}")

    # ── Data Display ──
    if st.session_state.uploaded_data is None:
        return

    df = st.session_state.uploaded_data

    # Dataset statistics
    st.markdown("### 📈 Dataset Statistics")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Rows", f"{len(df):,}")
    with col2:
        st.metric("Total Columns", len(df.columns))
    with col3:
        total_bytes = df.memory_usage(deep=True).sum()
        st.metric("Memory Size", f"{total_bytes / 1024:.1f} KB")
    with col4:
        missing = df.isnull().sum().sum()
        st.metric("Missing Values", missing)

    st.markdown("---")

    # Column detection
    st.markdown("### 🔍 Auto-Detected Columns")
    instruction_patterns = ['instruction', 'prompt', 'question', 'query', 'user', 'input_text']
    output_patterns = ['output', 'response', 'answer', 'completion', 'assistant', 'target']

    detected_instruction = _first_matching_column(df.columns, instruction_patterns)
    detected_output = _first_matching_column(df.columns, output_patterns)

    col1, col2 = st.columns(2)
    with col1:
        if detected_instruction is not None:
            # Column names come from the uploaded file: escape before embedding in HTML
            st.markdown(f'<span class="success-badge">Instruction: {html.escape(str(detected_instruction))}</span>', unsafe_allow_html=True)
        else:
            st.markdown('<span class="warning-badge">Instruction: Not detected</span>', unsafe_allow_html=True)
    with col2:
        if detected_output is not None:
            st.markdown(f'<span class="success-badge">Output: {html.escape(str(detected_output))}</span>', unsafe_allow_html=True)
        else:
            st.markdown('<span class="warning-badge">Output: Not detected</span>', unsafe_allow_html=True)

    st.markdown("---")

    # Full data preview (scrollable)
    st.markdown("### 👀 Complete Data Preview")
    st.caption(f"Showing all **{len(df):,}** rows. Scroll to browse the full dataset.")
    st.dataframe(df, use_container_width=True, height=450)

    # Download raw data
    st.markdown("### 📥 Download Dataset")
    # The stored filename can be None after a removal; fall back to a default
    base_name = (st.session_state.get('uploaded_filename') or 'dataset').rsplit('.', 1)[0]
    dl1, dl2 = st.columns(2)
    with dl1:
        csv_data = df.to_csv(index=False).encode('utf-8')
        st.download_button("⬇️ Download as CSV", csv_data,
                           file_name=f"{base_name}.csv",
                           mime="text/csv")
    with dl2:
        json_data = df.to_json(orient='records', indent=2).encode('utf-8')
        st.download_button("⬇️ Download as JSON", json_data,
                           file_name=f"{base_name}.json",
                           mime="application/json")

    # Column summary
    st.markdown("### 📋 Column Summary")
    st.dataframe(pd.DataFrame(_summarize_columns(df)), use_container_width=True)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
# ============================================================================
|
| 449 |
+
# PAGE: DATA PROCESSING
|
| 450 |
+
# ============================================================================
|
| 451 |
+
def _first_matching_column(columns, patterns):
    """Return the first column whose lowercased name contains any of *patterns*.

    Column labels are passed through ``str()`` first so that non-string labels
    (e.g. the integer labels of a header-less CSV) do not raise.
    Returns ``None`` when no column matches.
    """
    for col in columns:
        col_lower = str(col).lower()
        if any(p in col_lower for p in patterns):
            return col
    return None


def _split_percentages(train_ratio):
    """Return ``(train_pct, val_pct)`` as whole percentages that sum to 100.

    ``int((1 - 0.8) * 100)`` evaluates to 19 because of binary floating-point
    error, so the training share is rounded and the validation share is
    derived from it instead of being truncated independently.
    """
    train_pct = round(train_ratio * 100)
    return train_pct, 100 - train_pct


def _safe_file_stem(name, fallback="assistant"):
    """Turn free-text *name* into a string that is safe to use as a file stem.

    Path separators and characters that are invalid in file names are replaced
    with ``_`` so the value cannot point outside the output directory.
    *fallback* is returned when nothing usable is left.
    """
    import re

    stem = re.sub(r'[\\/:*?"<>|\x00-\x1f]+', "_", str(name or "")).strip(" .")
    return stem or fallback


def _parse_language_codes(raw):
    """Split a comma-separated string into a list of non-empty, stripped codes."""
    return [code.strip() for code in raw.split(',') if code.strip()]


def render_processing():
    """Render the "Advanced Data Processing" page.

    Reads the uploaded DataFrame from ``st.session_state.uploaded_data``, lets
    the user map columns and configure each preprocessing stage, and, when the
    run button is pressed, executes ``PreprocessingPipeline`` and exports the
    formatted train/validation sets as JSONL under ``./output/processed_data``.

    Session state written: ``training_goal``, ``safe_preset``,
    ``pipeline_status['data']``, ``processed_data_path`` and (when redirecting
    to the upload page) ``current_page``.
    """
    st.markdown('<p class="gradient-header">🧹 Advanced Data Processing</p>', unsafe_allow_html=True)

    if st.session_state.uploaded_data is None:
        st.warning("⚠️ Please upload a dataset first!")
        if st.button("📊 Go to Data Upload"):
            st.session_state.current_page = 'data'
            st.rerun()
        return

    df = st.session_state.uploaded_data

    # ── Dataset Stats Header ──
    st.markdown("### 📈 Dataset Statistics")
    sc1, sc2, sc3, sc4 = st.columns(4)
    with sc1:
        st.metric("Total Rows", f"{len(df):,}")
    with sc2:
        st.metric("Columns", len(df.columns))
    with sc3:
        # Average character length of the first column only.
        avg_len = int(df.iloc[:, 0].astype(str).str.len().mean()) if len(df) > 0 else 0
        st.metric("Avg Text Length", f"{avg_len:,} chars")
    with sc4:
        # Rough estimate: ~4 characters per token.
        est_tokens = int(avg_len * len(df) / 4) if avg_len > 0 else 0
        st.metric("Est. Total Tokens", f"{est_tokens:,}")

    st.markdown("---")

    # ── Training Goal ──
    goal = st.text_input(
        "Training Goal",
        value=st.session_state.training_goal or "assistant",
        help="e.g., medical_assistant, customer_support, code_helper"
    )
    st.session_state.training_goal = goal
    # The goal is free text; only a sanitized form is used in file names.
    file_stem = _safe_file_stem(goal)

    # ── Column Mapping ──
    st.markdown("### 🎯 Column Mapping")
    instruction_patterns = ['instruction', 'prompt', 'question', 'query', 'user', 'input_text', 'human']
    output_patterns = ['output', 'response', 'answer', 'completion', 'assistant', 'target']
    input_patterns = ['context', 'input', 'background', 'reference']

    available_columns = list(df.columns)
    detected_instruction = _first_matching_column(available_columns, instruction_patterns)
    detected_output = _first_matching_column(available_columns, output_patterns)
    detected_input = _first_matching_column(available_columns, input_patterns)

    mc1, mc2, mc3 = st.columns(3)
    with mc1:
        instruction_col = st.selectbox(
            "Instruction Column *", options=available_columns,
            index=available_columns.index(detected_instruction) if detected_instruction is not None else 0,
            help="Column containing instructions/prompts/questions")
    with mc2:
        if detected_output is not None:
            default_output_idx = available_columns.index(detected_output)
        else:
            # Fall back to the second column when there is one.
            default_output_idx = 1 if len(available_columns) > 1 else 0
        output_col = st.selectbox(
            "Output Column *", options=available_columns,
            index=default_output_idx,
            help="Column containing responses/answers/outputs")
    with mc3:
        input_col_options = ["None"] + available_columns
        default_input_idx = input_col_options.index(detected_input) if detected_input is not None else 0
        input_col_selection = st.selectbox(
            "Input/Context Column (Optional)", options=input_col_options,
            index=default_input_idx, help="Optional column containing additional context")
        input_col = None if input_col_selection == "None" else input_col_selection

    st.markdown("---")

    # ── Safe Preset Button ──
    if st.button("🛡️ Load Safe Preset", help="Apply recommended defaults for most datasets"):
        st.session_state['safe_preset'] = True
        st.rerun()

    use_safe = st.session_state.get('safe_preset', False)

    # ====================================================================
    # 1️⃣ Text Cleaning Controls
    # ====================================================================
    with st.expander("1️⃣ Text Cleaning Controls", expanded=False):
        tc1, tc2 = st.columns(2)
        with tc1:
            clean_html = st.checkbox("Remove HTML Tags", value=use_safe, help="Strip all HTML/XML tags from text")
            clean_urls = st.checkbox("Remove URLs", value=use_safe, help="Remove http/https/www links")
            clean_emojis = st.checkbox("Remove Emojis", value=False, help="Strip emoji characters")
            clean_whitespace = st.checkbox("Normalize Whitespace", value=True, help="Collapse multiple spaces/tabs into one")
        with tc2:
            clean_lowercase = st.checkbox("Lowercase All Text", value=False, help="Convert text to lowercase (disable to preserve case)")
            clean_special = st.checkbox("Remove Special Characters", value=False, help="Keep only alphanumeric + basic punctuation")
            clean_linebreaks = st.checkbox("Strip Extra Line Breaks", value=True, help="Reduce 3+ newlines to double newlines")

    # ====================================================================
    # 2️⃣ Tokenization Controls
    # ====================================================================
    with st.expander("2️⃣ Tokenization Controls", expanded=False):
        tk1, tk2 = st.columns(2)
        with tk1:
            tokenizer_choice = st.selectbox(
                "Tokenizer", ["tiktoken", "HuggingFace"],
                help="tiktoken = OpenAI-compatible, HuggingFace = model-specific tokenizer")
            if tokenizer_choice == "HuggingFace":
                hf_model_name = st.text_input(
                    "HF Model Name", value="meta-llama/Llama-3-8b",
                    help="HuggingFace model name for tokenizer")
            else:
                hf_model_name = ""
            max_total_tokens = st.slider(
                "Max Tokens per Sample", 128, 8192, 2048,
                help="Maximum total tokens allowed per sample")
        with tk2:
            truncate_long = st.checkbox(
                "Truncate Long Samples", value=False,
                help="Cut text exceeding max tokens")
            split_long = st.checkbox(
                "Split Long Samples into Chunks", value=False,
                help="Break long texts into overlapping chunks")
            if split_long:
                split_overlap = st.slider(
                    "Chunk Overlap Tokens", 0, 200, 50,
                    help="Number of overlapping tokens between chunks")
            else:
                split_overlap = 50

        # Both the preview and the pipeline config use the same tokenizer name.
        tokenizer_name = "tiktoken" if tokenizer_choice == "tiktoken" else hf_model_name

        # Token stats preview
        if st.button("📊 Show Token Stats Preview", key="token_stats_btn"):
            with st.spinner("Counting tokens..."):
                try:
                    from preprocessing.tokenization import TokenizationConfig, get_tokenizer, compute_token_stats
                    tk_cfg = TokenizationConfig(
                        tokenizer_name=tokenizer_name,
                    )
                    tokenizer = get_tokenizer(tk_cfg)
                    is_tiktoken = tokenizer_choice == "tiktoken"
                    stats_cols = [c for c in [instruction_col, output_col] if c in df.columns]
                    # Only the first 200 rows are sampled for the preview.
                    stats = compute_token_stats(df.head(200), stats_cols, tokenizer, is_tiktoken)
                    for col_name, s in stats.items():
                        st.markdown(f"**{col_name}**: min={s['min']}, max={s['max']}, mean={s['mean']}, p95={s['p95']}")
                except Exception as e:
                    # Best-effort preview: tokenizer packages may be missing.
                    st.warning(f"Could not compute token stats: {e}")

    # ====================================================================
    # 3️⃣ System Prompt Configuration
    # ====================================================================
    with st.expander("3️⃣ System Prompt Configuration", expanded=False):
        system_prompt_text = st.text_area(
            "Global System Prompt",
            value="You are a helpful AI assistant.",
            height=100, help="System prompt prepended to every sample in chat format")
        prepend_system = st.checkbox(
            "Prepend System Prompt to All Samples", value=True,
            help="Include this system prompt in all formatted entries")

        if st.button("👁️ Preview Formatted Chat JSON", key="preview_chat_btn"):
            try:
                from preprocessing.system_prompt import preview_formatted_json
                preview = preview_formatted_json(df, system_prompt_text, instruction_col, output_col, input_col, n=2)
                st.code(preview, language="json")
            except Exception as e:
                st.warning(f"Preview error: {e}")

    # ====================================================================
    # 4️⃣ Dataset Balancing
    # ====================================================================
    with st.expander("4️⃣ Dataset Balancing (Classification)", expanded=False):
        balance_enabled = st.checkbox(
            "Enable Class Balancing", value=False,
            help="Balance class distribution for classification tasks")
        if balance_enabled:
            label_col_options = available_columns
            label_col = st.selectbox(
                "Label Column", options=label_col_options,
                help="Column containing class labels")
            balance_strategy = st.radio(
                "Strategy", ["none", "oversample", "undersample"],
                help="Oversample = duplicate minority, Undersample = drop majority")

            # Show distribution chart
            if label_col in df.columns:
                from preprocessing.dataset_balancing import compute_label_distribution
                dist = compute_label_distribution(df, label_col)
                if dist:
                    fig = px.bar(
                        x=list(dist.keys()), y=list(dist.values()),
                        labels={'x': 'Label', 'y': 'Count'}, title="Label Distribution")
                    fig.update_layout(
                        paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
                        font_color='#e2e8f0')
                    st.plotly_chart(fig, use_container_width=True)
        else:
            label_col = None
            balance_strategy = "none"

    # ====================================================================
    # 5️⃣ Quality Filters
    # ====================================================================
    with st.expander("5️⃣ Quality Filters", expanded=False):
        qf1, qf2 = st.columns(2)
        with qf1:
            min_words = st.number_input(
                "Min Word Count", min_value=0, value=3 if use_safe else 0,
                help="Minimum words required per sample (0 = no filter)")
            max_words = st.number_input(
                "Max Word Count", min_value=0, value=0,
                help="Maximum words allowed per sample (0 = no limit)")
            profanity_filter = st.checkbox(
                "Profanity Filter", value=False,
                help="Remove samples containing profane language")
        with qf2:
            language_filter = st.checkbox(
                "Language Detection Filter", value=False,
                help="Keep only samples in specified languages")
            if language_filter:
                allowed_langs = st.text_input(
                    "Allowed Languages (comma-separated)", value="en",
                    help="ISO 639-1 codes, e.g. en,fr,de")
            else:
                allowed_langs = "en"
            remove_low_quality = st.checkbox(
                "Remove Low-Quality Responses", value=use_safe,
                help="Remove short / generic / placeholder responses")

    # ====================================================================
    # 6️⃣ Deduplication Advanced
    # ====================================================================
    with st.expander("6️⃣ Deduplication", expanded=False):
        dedup_exact = st.checkbox(
            "Remove Exact Duplicates", value=True,
            help="Remove rows with identical instruction text")
        dedup_semantic = st.checkbox(
            "Remove Semantic Duplicates", value=False,
            help="Use TF-IDF cosine similarity to find near-duplicates")
        if dedup_semantic:
            semantic_threshold = st.slider(
                "Similarity Threshold", 0.5, 1.0, 0.90, 0.01,
                help="Cosine similarity above this threshold = duplicate (higher = stricter)")
        else:
            semantic_threshold = 0.90

    # ====================================================================
    # 7️⃣ Train / Validation Split
    # ====================================================================
    with st.expander("7️⃣ Train / Validation Split", expanded=False):
        split_enabled = st.checkbox(
            "Enable Train/Val Split", value=True,
            help="Split dataset into training and validation sets")
        if split_enabled:
            train_ratio = st.slider(
                "Train Ratio", 0.5, 0.95, 0.9 if use_safe else 0.8, 0.05,
                help="Proportion of data used for training")
            # Rounded, not truncated, so 0.8 shows 80/20 rather than 80/19.
            train_pct, val_pct = _split_percentages(train_ratio)
            st.markdown(f"**Split**: {train_pct}% Train / {val_pct}% Validation")
            random_seed = st.number_input(
                "Random Seed", min_value=0, value=42,
                help="Seed for reproducible splits")
            shuffle_data = st.checkbox(
                "Shuffle Before Split", value=True,
                help="Randomly shuffle data before splitting")
        else:
            train_ratio = 0.8
            random_seed = 42
            shuffle_data = True

    # ====================================================================
    # 8️⃣ Output Formatting
    # ====================================================================
    with st.expander("8️⃣ Output Formatting", expanded=False):
        format_type = st.selectbox(
            "Export Format", ["openai_chat", "completion", "classification", "custom"],
            help="OpenAI Chat = messages format, Completion = prompt/completion, Classification = text/label")

        custom_schema = {}
        if format_type == "custom":
            st.markdown("**Define Custom Schema** (output_key → source_column)")
            num_fields = st.number_input("Number of Fields", 1, 10, 2)
            for i in range(int(num_fields)):
                fc1, fc2 = st.columns(2)
                with fc1:
                    key = st.text_input(f"Output Key {i+1}", value=f"field_{i+1}", key=f"ckey_{i}")
                with fc2:
                    val = st.selectbox(f"Source Column {i+1}", options=available_columns, key=f"cval_{i}")
                # A repeated output key overwrites the earlier mapping.
                custom_schema[key] = val

    # ====================================================================
    # 9️⃣ Safety & PII Filtering
    # ====================================================================
    with st.expander("9️⃣ Safety & PII Filtering", expanded=False):
        pii1, pii2 = st.columns(2)
        with pii1:
            pii_emails = st.checkbox(
                "Detect & Mask Emails", value=use_safe,
                help="Replace email addresses with [REDACTED]")
            pii_phones = st.checkbox(
                "Detect & Mask Phone Numbers", value=use_safe,
                help="Replace phone numbers with [REDACTED]")
            pii_ids = st.checkbox(
                "Detect & Mask CNIC/SSN", value=use_safe,
                help="Replace national ID / SSN patterns with [REDACTED]")
        with pii2:
            pii_keys = st.checkbox(
                "Detect & Mask API Keys", value=use_safe,
                help="Replace long hex/base64 strings that look like secrets")
            pii_addresses = st.checkbox(
                "Detect & Mask Addresses", value=False,
                help="Replace street addresses and zip codes")

    # ====================================================================
    # 🔟 Augmentation (Optional)
    # ====================================================================
    with st.expander("🔟 Augmentation (Optional)", expanded=False):
        aug_enabled = st.checkbox(
            "Enable Data Augmentation", value=False,
            help="Generate synthetic variations of existing samples")
        if aug_enabled:
            ag1, ag2 = st.columns(2)
            with ag1:
                aug_paraphrase = st.checkbox(
                    "Paraphrase Instructions", value=True,
                    help="Synonym-based paraphrasing of instructions")
                aug_variations = st.checkbox(
                    "Generate Variations", value=False,
                    help="Minor text variations (punctuation, casing)")
            with ag2:
                aug_backtranslate = st.checkbox(
                    "Back Translation", value=False,
                    help="Simulate back-translation for diversity")
                aug_tone = st.checkbox(
                    "Tone Rewriting", value=False,
                    help="Rewrite instructions in different tones")
            aug_factor = st.slider(
                "Augmentation Factor", 1, 5, 1,
                help="Number of augmented copies per original sample")
        else:
            aug_paraphrase = aug_variations = aug_backtranslate = aug_tone = False
            aug_factor = 1

    st.markdown("---")

    # ── Run Pipeline Button ──
    if st.button("🚀 Run Advanced Processing Pipeline", type="primary", use_container_width=True):
        st.session_state.pipeline_status['data'] = 'running'

        with st.spinner("Running preprocessing pipeline..."):
            progress_bar = st.progress(0)
            status_text = st.empty()

            try:
                from preprocessing.pipeline import PreprocessingPipeline, PreprocessingConfig
                from preprocessing.text_cleaning import TextCleaningConfig
                from preprocessing.tokenization import TokenizationConfig
                from preprocessing.system_prompt import SystemPromptConfig
                from preprocessing.dataset_balancing import BalancingConfig
                from preprocessing.quality_filters import QualityFilterConfig
                from preprocessing.deduplication import DeduplicationConfig
                from preprocessing.train_val_split import SplitConfig
                from preprocessing.output_formatter import OutputFormatConfig, format_dataset, export_jsonl, generate_preview
                from preprocessing.pii_filter import PIIFilterConfig
                from preprocessing.augmentation import AugmentationConfig

                # The label column only matters when balancing is enabled.
                active_label_col = label_col if balance_enabled else None

                # Build config from UI values
                config = PreprocessingConfig(
                    instruction_col=instruction_col,
                    output_col=output_col,
                    input_col=input_col,
                    label_col=active_label_col,
                    text_cleaning=TextCleaningConfig(
                        remove_html=clean_html, remove_urls=clean_urls,
                        remove_emojis=clean_emojis, normalize_whitespace=clean_whitespace,
                        lowercase=clean_lowercase, remove_special_chars=clean_special,
                        strip_extra_linebreaks=clean_linebreaks,
                    ),
                    tokenization=TokenizationConfig(
                        tokenizer_name=tokenizer_name,
                        max_total_tokens=max_total_tokens,
                        truncate_long=truncate_long, split_long=split_long,
                        split_overlap=split_overlap,
                    ),
                    system_prompt=SystemPromptConfig(
                        system_prompt=system_prompt_text,
                        prepend_to_all=prepend_system,
                    ),
                    balancing=BalancingConfig(
                        enabled=balance_enabled,
                        label_column=label_col if balance_enabled else "",
                        strategy=balance_strategy if balance_enabled else "none",
                    ),
                    quality_filters=QualityFilterConfig(
                        min_word_count=min_words, max_word_count=max_words,
                        profanity_filter=profanity_filter,
                        language_filter=language_filter,
                        # Empty entries (e.g. from a trailing comma) are dropped.
                        allowed_languages=_parse_language_codes(allowed_langs),
                        remove_low_quality=remove_low_quality,
                    ),
                    deduplication=DeduplicationConfig(
                        remove_exact=dedup_exact, remove_semantic=dedup_semantic,
                        semantic_threshold=semantic_threshold,
                    ),
                    split=SplitConfig(
                        enabled=split_enabled, train_ratio=train_ratio,
                        random_seed=int(random_seed), shuffle=shuffle_data,
                    ),
                    output_format=OutputFormatConfig(
                        format_type=format_type, custom_schema=custom_schema,
                    ),
                    pii_filter=PIIFilterConfig(
                        filter_emails=pii_emails, filter_phones=pii_phones,
                        filter_id_numbers=pii_ids, filter_api_keys=pii_keys,
                        filter_addresses=pii_addresses,
                    ),
                    augmentation=AugmentationConfig(
                        enabled=aug_enabled, paraphrase=aug_paraphrase,
                        generate_variations=aug_variations,
                        back_translate=aug_backtranslate,
                        tone_rewrite=aug_tone,
                        augmentation_factor=aug_factor,
                    ),
                )

                def progress_cb(stage_name, pct):
                    """Mirror pipeline progress into the status line and progress bar."""
                    status_text.text(f"⚙️ {stage_name}...")
                    progress_bar.progress(min(pct, 100))

                pipeline = PreprocessingPipeline(config)
                train_df, val_df, logs = pipeline.run(df, progress_callback=progress_cb)

                # Format output
                sys_prompt = system_prompt_text if prepend_system else ""
                format_kwargs = dict(
                    system_prompt=sys_prompt,
                    instruction_col=instruction_col,
                    output_col=output_col,
                    input_col=input_col,
                    label_col=active_label_col,
                )
                formatted_data = format_dataset(train_df, config.output_format, **format_kwargs)

                # Export
                output_dir = Path("./output/processed_data")
                output_dir.mkdir(parents=True, exist_ok=True)
                train_path = export_jsonl(formatted_data, str(output_dir / f"{file_stem}_train.jsonl"))

                val_path = None
                if len(val_df) > 0:
                    val_formatted = format_dataset(val_df, config.output_format, **format_kwargs)
                    val_path = export_jsonl(val_formatted, str(output_dir / f"{file_stem}_val.jsonl"))

                progress_bar.progress(100)
                status_text.text("✅ Pipeline complete!")

                st.session_state.processed_data_path = train_path
                st.session_state.pipeline_status['data'] = 'complete'

                # ── Results ──
                st.success(f"✅ Training data saved to: `{train_path}`")
                if val_path:
                    st.success(f"✅ Validation data saved to: `{val_path}`")

                # Stats
                rc1, rc2, rc3, rc4 = st.columns(4)
                with rc1:
                    st.metric("Original Rows", f"{len(df):,}")
                with rc2:
                    st.metric("Train Samples", f"{len(train_df):,}")
                with rc3:
                    st.metric("Val Samples", f"{len(val_df):,}")
                with rc4:
                    # Augmentation can grow the dataset, so clamp at zero.
                    removed = len(df) - len(train_df) - len(val_df)
                    st.metric("Removed", f"{max(0, removed):,}")

                # ── Pipeline Logs ──
                st.markdown("### 📋 Pipeline Logs")
                log_data = [
                    {
                        'Stage': log.stage,
                        'Description': log.description,
                        'Rows Before': log.rows_before,
                        'Rows After': log.rows_after,
                        'Delta': log.rows_delta,
                        'Time (ms)': log.duration_ms,
                    }
                    for log in logs
                ]
                st.dataframe(pd.DataFrame(log_data), use_container_width=True)

                # ── Preview ──
                st.markdown("### 👁️ Output Preview")
                preview_json = generate_preview(formatted_data, n=3)
                st.code(preview_json, language="json")

                # ── Download ──
                st.markdown("### 📥 Download")
                dl1, dl2 = st.columns(2)
                with dl1:
                    with open(train_path, 'r', encoding='utf-8') as f:
                        st.download_button(
                            "⬇️ Download Train JSONL", f.read(),
                            file_name=f"{file_stem}_train.jsonl", mime="application/jsonl")
                with dl2:
                    if val_path and Path(val_path).exists():
                        with open(val_path, 'r', encoding='utf-8') as f:
                            st.download_button(
                                "⬇️ Download Val JSONL", f.read(),
                                file_name=f"{file_stem}_val.jsonl", mime="application/jsonl")

            except Exception as e:
                # UI boundary: surface the failure on the page instead of crashing the app.
                st.session_state.pipeline_status['data'] = 'error'
                st.error(f"❌ Pipeline Error: {str(e)}")
                import traceback
                st.code(traceback.format_exc())

    # Show previously processed data
    if st.session_state.processed_data_path:
        st.markdown("---")
        st.markdown("### 📂 Last Processed Data")
        try:
            from itertools import islice

            processed_path = Path(st.session_state.processed_data_path)
            if processed_path.exists():
                with open(processed_path, encoding='utf-8') as f:
                    # Stream the file: only the first five non-blank records are shown.
                    non_blank = (line for line in f if line.strip())
                    samples = [json.loads(line) for line in islice(non_blank, 5)]
                for i, sample in enumerate(samples):
                    with st.expander(f"Sample {i+1}"):
                        st.json(sample)
        except Exception as e:
            st.warning(f"Could not load preview: {e}")
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
# ============================================================================
|
| 948 |
+
# PAGE: TRAINING
|
| 949 |
+
# ============================================================================
|
| 950 |
+
def render_training():
|
| 951 |
+
st.markdown('<p class="gradient-header">🚀 Model Training</p>', unsafe_allow_html=True)
|
| 952 |
+
|
| 953 |
+
# Check prerequisites
|
| 954 |
+
if st.session_state.processed_data_path is None:
|
| 955 |
+
st.warning("⚠️ Please process your data first!")
|
| 956 |
+
if st.button("🧹 Go to Processing"):
|
| 957 |
+
st.session_state.current_page = 'process'
|
| 958 |
+
st.rerun()
|
| 959 |
+
return
|
| 960 |
+
|
| 961 |
+
# ── GPU Detection ──
|
| 962 |
+
try:
|
| 963 |
+
import torch
|
| 964 |
+
has_gpu = torch.cuda.is_available()
|
| 965 |
+
if has_gpu:
|
| 966 |
+
gpu_name = torch.cuda.get_device_name(0)
|
| 967 |
+
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
|
| 968 |
+
st.success(f"✅ GPU Available: **{gpu_name}** ({gpu_memory:.1f} GB)")
|
| 969 |
+
except Exception:
|
| 970 |
+
has_gpu = False
|
| 971 |
+
|
| 972 |
+
# ── Download Preprocessed Data (always available) ──
|
| 973 |
+
st.markdown("### 📥 Preprocessed Training Data")
|
| 974 |
+
processed_path = Path(st.session_state.processed_data_path)
|
| 975 |
+
if processed_path.exists():
|
| 976 |
+
with open(processed_path, 'r', encoding='utf-8') as f:
|
| 977 |
+
processed_content = f.read()
|
| 978 |
+
dl1, dl2 = st.columns(2)
|
| 979 |
+
with dl1:
|
| 980 |
+
st.download_button("⬇️ Download Training JSONL", processed_content,
|
| 981 |
+
file_name=processed_path.name, mime="application/jsonl")
|
| 982 |
+
with dl2:
|
| 983 |
+
# Check for validation file
|
| 984 |
+
val_path = processed_path.parent / processed_path.name.replace('_train', '_val')
|
| 985 |
+
if val_path.exists():
|
| 986 |
+
with open(val_path, 'r', encoding='utf-8') as f:
|
| 987 |
+
st.download_button("⬇️ Download Validation JSONL", f.read(),
|
| 988 |
+
file_name=val_path.name, mime="application/jsonl")
|
| 989 |
+
try:
|
| 990 |
+
sample_count = sum(1 for _ in processed_content.split('\n') if _.strip())
|
| 991 |
+
except Exception:
|
| 992 |
+
sample_count = 0
|
| 993 |
+
st.info(f"📊 Dataset: **{sample_count:,}** samples ready for training")
|
| 994 |
+
else:
|
| 995 |
+
st.warning("Processed data file not found.")
|
| 996 |
+
|
| 997 |
+
st.markdown("---")
|
| 998 |
+
|
| 999 |
+
# ====================================================================
|
| 1000 |
+
# TWO PATHS: GPU Training OR Colab Notebook
|
| 1001 |
+
# ====================================================================
|
| 1002 |
+
if has_gpu:
|
| 1003 |
+
training_mode = "gpu"
|
| 1004 |
+
else:
|
| 1005 |
+
training_mode = st.radio("🖥️ Select Training Mode", [
|
| 1006 |
+
"☁️ Use Google Colab (Recommended – Free GPU)",
|
| 1007 |
+
"📤 Upload Fine-Tuned Model (Already trained externally)"
|
| 1008 |
+
], help="No GPU detected on this machine. Choose how to proceed.")
|
| 1009 |
+
|
| 1010 |
+
# ====================================================================
|
| 1011 |
+
# PATH A: GPU Training (local)
|
| 1012 |
+
# ====================================================================
|
| 1013 |
+
if training_mode == "gpu":
|
| 1014 |
+
st.markdown("### ⚙️ Training Configuration")
|
| 1015 |
+
|
| 1016 |
+
col1, col2 = st.columns(2)
|
| 1017 |
+
with col1:
|
| 1018 |
+
model_source = st.radio("Model Source", ["Preset Models", "Custom HuggingFace Model"])
|
| 1019 |
+
if model_source == "Preset Models":
|
| 1020 |
+
base_model = st.selectbox("Base Model", [
|
| 1021 |
+
"unsloth/llama-3-8b-bnb-4bit",
|
| 1022 |
+
"unsloth/llama-3-70b-bnb-4bit",
|
| 1023 |
+
"unsloth/mistral-7b-bnb-4bit",
|
| 1024 |
+
"unsloth/gemma-7b-bnb-4bit",
|
| 1025 |
+
])
|
| 1026 |
+
else:
|
| 1027 |
+
base_model = st.text_input("HuggingFace Model ID",
|
| 1028 |
+
value="unsloth/llama-3-8b-bnb-4bit",
|
| 1029 |
+
help="Enter any HuggingFace model ID, e.g. 'meta-llama/Llama-3-8b', 'mistralai/Mistral-7B-v0.1'")
|
| 1030 |
+
max_seq_length = st.slider("Max Sequence Length", 512, 4096, 2048)
|
| 1031 |
+
|
| 1032 |
+
with col2:
|
| 1033 |
+
dataset_size = sample_count if sample_count > 0 else 1000
|
| 1034 |
+
if dataset_size < 1000:
|
| 1035 |
+
auto_rank, auto_alpha, auto_lr, auto_epochs = 8, 16, 2e-4, 5
|
| 1036 |
+
size_category = "Small"
|
| 1037 |
+
elif dataset_size < 10000:
|
| 1038 |
+
auto_rank, auto_alpha, auto_lr, auto_epochs = 16, 32, 1e-4, 3
|
| 1039 |
+
size_category = "Medium"
|
| 1040 |
+
else:
|
| 1041 |
+
auto_rank, auto_alpha, auto_lr, auto_epochs = 32, 64, 5e-5, 2
|
| 1042 |
+
size_category = "Large"
|
| 1043 |
+
st.success(f"Auto-configured for **{size_category}** dataset ({dataset_size:,} samples)")
|
| 1044 |
+
|
| 1045 |
+
st.markdown("---")
|
| 1046 |
+
|
| 1047 |
+
with st.expander("🔧 Advanced Hyperparameters"):
|
| 1048 |
+
hc1, hc2, hc3 = st.columns(3)
|
| 1049 |
+
with hc1:
|
| 1050 |
+
lora_rank = st.slider("LoRA Rank", 4, 64, auto_rank)
|
| 1051 |
+
lora_alpha = st.slider("LoRA Alpha", 8, 128, auto_alpha)
|
| 1052 |
+
with hc2:
|
| 1053 |
+
learning_rate = st.select_slider("Learning Rate",
|
| 1054 |
+
options=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4], value=auto_lr)
|
| 1055 |
+
num_epochs = st.slider("Epochs", 1, 10, auto_epochs)
|
| 1056 |
+
with hc3:
|
| 1057 |
+
batch_size = st.slider("Batch Size", 1, 16, 4)
|
| 1058 |
+
gradient_accumulation = st.slider("Gradient Accumulation", 1, 8, 4)
|
| 1059 |
+
|
| 1060 |
+
st.markdown("---")
|
| 1061 |
+
|
| 1062 |
+
col1, col2, col3 = st.columns([1, 2, 1])
|
| 1063 |
+
with col2:
|
| 1064 |
+
if st.button("🚀 Start Training", type="primary", use_container_width=True):
|
| 1065 |
+
st.session_state.pipeline_status['training'] = 'running'
|
| 1066 |
+
with st.spinner("Training in progress..."):
|
| 1067 |
+
progress_bar = st.progress(0)
|
| 1068 |
+
status_text = st.empty()
|
| 1069 |
+
try:
|
| 1070 |
+
from agents.training_pilot import TrainingPilot, HyperParams
|
| 1071 |
+
status_text.text("📦 Loading model...")
|
| 1072 |
+
progress_bar.progress(10)
|
| 1073 |
+
pilot = TrainingPilot(
|
| 1074 |
+
base_model=base_model,
|
| 1075 |
+
max_seq_length=max_seq_length,
|
| 1076 |
+
output_dir="./output/models"
|
| 1077 |
+
)
|
| 1078 |
+
status_text.text("🚀 Training...")
|
| 1079 |
+
progress_bar.progress(30)
|
| 1080 |
+
result = pilot.run(
|
| 1081 |
+
data_path=st.session_state.processed_data_path,
|
| 1082 |
+
output_name=st.session_state.training_goal
|
| 1083 |
+
)
|
| 1084 |
+
progress_bar.progress(100)
|
| 1085 |
+
status_text.text("✅ Training complete!")
|
| 1086 |
+
st.session_state.model_path = result.model_path
|
| 1087 |
+
st.session_state.pipeline_status['training'] = 'complete'
|
| 1088 |
+
st.success(f"✅ Model saved to: `{result.model_path}`")
|
| 1089 |
+
rc1, rc2, rc3 = st.columns(3)
|
| 1090 |
+
with rc1:
|
| 1091 |
+
st.metric("Final Loss", f"{result.final_loss:.4f}")
|
| 1092 |
+
with rc2:
|
| 1093 |
+
st.metric("Training Time", f"{result.training_time:.1f}s")
|
| 1094 |
+
with rc3:
|
| 1095 |
+
st.metric("Total Steps", result.num_steps)
|
| 1096 |
+
except Exception as e:
|
| 1097 |
+
st.session_state.pipeline_status['training'] = 'error'
|
| 1098 |
+
st.error(f"❌ Training failed: {str(e)}")
|
| 1099 |
+
import traceback
|
| 1100 |
+
st.code(traceback.format_exc())
|
| 1101 |
+
|
| 1102 |
+
# ====================================================================
|
| 1103 |
+
# PATH B: Google Colab Notebook
|
| 1104 |
+
# ====================================================================
|
| 1105 |
+
elif "Colab" in training_mode:
|
| 1106 |
+
st.markdown("### ☁️ Train on Google Colab (Free GPU)")
|
| 1107 |
+
st.markdown("""
|
| 1108 |
+
Since no GPU was detected on this machine, you can fine-tune your model on Google Colab with a free GPU.
|
| 1109 |
+
Follow these steps:
|
| 1110 |
+
""")
|
| 1111 |
+
|
| 1112 |
+
st.markdown("""
|
| 1113 |
+
**Step 1:** Download your preprocessed training data (above) ⬆️
|
| 1114 |
+
|
| 1115 |
+
**Step 2:** Download or copy the Colab notebook below
|
| 1116 |
+
|
| 1117 |
+
**Step 3:** Open [Google Colab](https://colab.research.google.com/) → Upload the notebook
|
| 1118 |
+
|
| 1119 |
+
**Step 4:** Upload your training JSONL to Colab's file browser
|
| 1120 |
+
|
| 1121 |
+
**Step 5:** Run all cells → Download the fine-tuned model
|
| 1122 |
+
|
| 1123 |
+
**Step 6:** Come back here → Upload your fine-tuned model results for evaluation
|
| 1124 |
+
""")
|
| 1125 |
+
|
| 1126 |
+
# Show / Download Colab notebook
|
| 1127 |
+
notebook_path = Path("./Auto_FineTune_Ops_Colab.ipynb")
|
| 1128 |
+
if notebook_path.exists():
|
| 1129 |
+
with open(notebook_path, 'r', encoding='utf-8') as f:
|
| 1130 |
+
notebook_content = f.read()
|
| 1131 |
+
|
| 1132 |
+
st.download_button("📓 Download Colab Notebook (.ipynb)", notebook_content,
|
| 1133 |
+
file_name="Auto_FineTune_Ops_Colab.ipynb", mime="application/json",
|
| 1134 |
+
type="primary", use_container_width=True)
|
| 1135 |
+
|
| 1136 |
+
with st.expander("👁️ View Notebook Code", expanded=False):
|
| 1137 |
+
try:
|
| 1138 |
+
import json as json_mod
|
| 1139 |
+
nb = json_mod.loads(notebook_content)
|
| 1140 |
+
for cell in nb.get('cells', []):
|
| 1141 |
+
if cell.get('cell_type') == 'code':
|
| 1142 |
+
source = ''.join(cell.get('source', []))
|
| 1143 |
+
if source.strip():
|
| 1144 |
+
st.code(source, language='python')
|
| 1145 |
+
elif cell.get('cell_type') == 'markdown':
|
| 1146 |
+
source = ''.join(cell.get('source', []))
|
| 1147 |
+
st.markdown(source)
|
| 1148 |
+
except Exception:
|
| 1149 |
+
st.code(notebook_content[:5000], language='json')
|
| 1150 |
+
else:
|
| 1151 |
+
st.warning("⚠️ Colab notebook not found at `Auto_FineTune_Ops_Colab.ipynb`")
|
| 1152 |
+
|
| 1153 |
+
st.markdown("---")
|
| 1154 |
+
st.markdown("### 📤 After Training on Colab")
|
| 1155 |
+
st.info("Once you've finished training on Colab, download your fine-tuned model outputs and upload them below for evaluation.")
|
| 1156 |
+
|
| 1157 |
+
# ====================================================================
|
| 1158 |
+
# PATH C: Upload Fine-Tuned Model / Results
|
| 1159 |
+
# ====================================================================
|
| 1160 |
+
else:
|
| 1161 |
+
st.markdown("### 📤 Upload Fine-Tuned Model Results")
|
| 1162 |
+
st.markdown("Upload outputs from your externally trained model for evaluation.")
|
| 1163 |
+
|
| 1164 |
+
# ── Upload Fine-Tuned Results (always shown at bottom) ──
|
| 1165 |
+
st.markdown("---")
|
| 1166 |
+
st.markdown("### 📦 Upload Fine-Tuned Results for Evaluation")
|
| 1167 |
+
st.caption("If you trained on Colab or another machine, upload your model outputs here.")
|
| 1168 |
+
|
| 1169 |
+
upload_tab1, upload_tab2 = st.tabs(["📊 Upload Evaluation Results (JSONL)", "📁 Upload Model Folder Path"])
|
| 1170 |
+
|
| 1171 |
+
with upload_tab1:
|
| 1172 |
+
ft_file = st.file_uploader("Upload fine-tuned model outputs (JSONL with predictions)",
|
| 1173 |
+
type=['jsonl', 'json'], key="ft_results_upload",
|
| 1174 |
+
help="JSONL file with model predictions/outputs from your fine-tuned model")
|
| 1175 |
+
if ft_file:
|
| 1176 |
+
try:
|
| 1177 |
+
ft_df = pd.read_json(ft_file, lines=ft_file.name.endswith('.jsonl'))
|
| 1178 |
+
st.success(f"✅ Loaded **{len(ft_df):,}** evaluation samples")
|
| 1179 |
+
st.dataframe(ft_df.head(5), use_container_width=True)
|
| 1180 |
+
|
| 1181 |
+
# Save for evaluation
|
| 1182 |
+
eval_output = Path("./output/eval_results")
|
| 1183 |
+
eval_output.mkdir(parents=True, exist_ok=True)
|
| 1184 |
+
eval_path = eval_output / f"finetuned_outputs_{ft_file.name}"
|
| 1185 |
+
ft_df.to_json(eval_path, orient='records', lines=True)
|
| 1186 |
+
|
| 1187 |
+
st.session_state.model_path = str(eval_path)
|
| 1188 |
+
st.session_state.pipeline_status['training'] = 'complete'
|
| 1189 |
+
st.success(f"✅ Results saved! You can now proceed to **Evaluation** page.")
|
| 1190 |
+
|
| 1191 |
+
if st.button("⚖️ Go to Evaluation"):
|
| 1192 |
+
st.session_state.current_page = 'evaluation'
|
| 1193 |
+
st.rerun()
|
| 1194 |
+
except Exception as e:
|
| 1195 |
+
st.error(f"Error loading file: {e}")
|
| 1196 |
+
|
| 1197 |
+
with upload_tab2:
|
| 1198 |
+
model_folder = st.text_input("Model Folder Path",
|
| 1199 |
+
placeholder="e.g., ./output/models/my_finetuned_model or /path/to/model",
|
| 1200 |
+
help="Local path to the fine-tuned model directory (LoRA adapter or full model)")
|
| 1201 |
+
if model_folder and st.button("✅ Set Model Path"):
|
| 1202 |
+
if Path(model_folder).exists():
|
| 1203 |
+
st.session_state.model_path = model_folder
|
| 1204 |
+
st.session_state.pipeline_status['training'] = 'complete'
|
| 1205 |
+
st.success(f"✅ Model path set to: `{model_folder}`")
|
| 1206 |
+
else:
|
| 1207 |
+
st.error(f"❌ Path not found: `{model_folder}`")
|
| 1208 |
+
|
| 1209 |
+
|
| 1210 |
+
# ============================================================================
|
| 1211 |
+
# PAGE: EVALUATION
|
| 1212 |
+
# ============================================================================
|
| 1213 |
+
def render_evaluation():
    """Render the Evaluation page.

    The page lets the user pick a judge provider and model, enter the
    provider's API key, optionally upload evaluation data, and press a
    "Run Full Evaluation" button.

    The charts and summary metrics shown are hard-coded demo values; the run
    button only validates the inputs and then displays status messages.
    """
    st.markdown('<p class="gradient-header">⚖️ Model Evaluation</p>', unsafe_allow_html=True)

    # ── Judge Provider Selection ──
    st.markdown("### 🤖 Select AI Judge Provider")
    st.caption("Choose which LLM provider to use as the evaluation judge. You can use any model you have API access to.")

    judge_provider = st.selectbox("AI Provider", [
        "OpenAI (GPT-4o, GPT-4-turbo, etc.)",
        "Anthropic (Claude 3.5, Claude 3 Opus, etc.)",
        "Google Gemini (Gemini Pro, Gemini 1.5, etc.)",
        "Groq (Llama, Mixtral, Gemma, etc.)",
        "Custom OpenAI-Compatible Endpoint"
    ], help="Select the AI provider whose model will act as the judge for evaluating your fine-tuned model.")

    st.markdown("---")
    st.markdown("### 🔑 API Configuration")

    # NOTE(review): keys are stored in os.environ, which is process-wide. If
    # this app ever serves more than one user from one process, a key entered
    # by one user is visible to the others - confirm the deployment model.
    #
    # `key_env_var` records which environment variable the *selected* provider
    # needs, so the run button can check that specific key.
    #
    # The OpenAI test uses startswith(): a substring test would also match
    # "Custom OpenAI-Compatible Endpoint" and make the custom branch unreachable.
    if judge_provider.startswith("OpenAI"):
        key_env_var = "OPENAI_API_KEY"
        col1, col2 = st.columns(2)
        with col1:
            openai_key = st.text_input("OpenAI API Key", type="password",
                                       help="Your OpenAI API key (starts with sk-)")
            if openai_key:
                os.environ["OPENAI_API_KEY"] = openai_key
        with col2:
            judge_model = st.selectbox("Judge Model", [
                "gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"
            ])

    elif "Anthropic" in judge_provider:
        key_env_var = "ANTHROPIC_API_KEY"
        col1, col2 = st.columns(2)
        with col1:
            anthropic_key = st.text_input("Anthropic API Key", type="password",
                                          help="Your Anthropic API key")
            if anthropic_key:
                os.environ["ANTHROPIC_API_KEY"] = anthropic_key
        with col2:
            judge_model = st.selectbox("Judge Model", [
                "claude-3-5-sonnet-20241022", "claude-3-opus-20240229",
                "claude-3-sonnet-20240229", "claude-3-haiku-20240307"
            ])

    elif "Gemini" in judge_provider:
        key_env_var = "GOOGLE_API_KEY"
        col1, col2 = st.columns(2)
        with col1:
            gemini_key = st.text_input("Google AI API Key", type="password",
                                       help="Your Google AI Studio API key for Gemini models")
            if gemini_key:
                os.environ["GOOGLE_API_KEY"] = gemini_key
        with col2:
            judge_model = st.selectbox("Judge Model", [
                "gemini-1.5-pro", "gemini-1.5-flash", "gemini-pro"
            ])

    elif "Groq" in judge_provider:
        key_env_var = "GROQ_API_KEY"
        col1, col2 = st.columns(2)
        with col1:
            groq_key = st.text_input("Groq API Key", type="password",
                                     help="Your Groq API key for fast inference")
            if groq_key:
                os.environ["GROQ_API_KEY"] = groq_key
        with col2:
            judge_model = st.selectbox("Judge Model", [
                "llama-3.1-70b-versatile", "llama-3.1-8b-instant",
                "mixtral-8x7b-32768", "gemma2-9b-it"
            ])

    else:  # Custom endpoint
        # Custom endpoints reuse the OpenAI environment variables.
        # NOTE(review): OPENAI_BASE_URL stays set after switching back to the
        # OpenAI provider - confirm whether it should be cleared there.
        key_env_var = "OPENAI_API_KEY"
        col1, col2 = st.columns(2)
        with col1:
            custom_base_url = st.text_input("API Base URL",
                                            placeholder="https://api.your-provider.com/v1",
                                            help="OpenAI-compatible API endpoint (e.g., vLLM, Ollama, LM Studio)")
            custom_api_key = st.text_input("API Key", type="password",
                                           help="API key for the custom endpoint (use 'none' for local servers)")
            if custom_api_key:
                os.environ["OPENAI_API_KEY"] = custom_api_key
            if custom_base_url:
                os.environ["OPENAI_BASE_URL"] = custom_base_url
        with col2:
            judge_model = st.text_input("Model Name",
                                        placeholder="e.g., my-model, llama-3-8b",
                                        help="Model identifier used by your custom endpoint")

    st.markdown("---")

    # ── Model / Results to Evaluate ──
    st.markdown("### 📊 Evaluation Data")

    if st.session_state.model_path:
        st.info(f"📦 Model / Results: `{st.session_state.model_path}`")
    else:
        st.warning("⚠️ No trained model or uploaded results found. You can upload evaluation data below or train a model first.")

    # Upload evaluation data
    eval_upload = st.file_uploader("Upload evaluation data (JSONL with instruction + model output)",
                                   type=['jsonl', 'json'], key="eval_data_upload",
                                   help="Upload a JSONL file containing instruction-response pairs to evaluate")
    if eval_upload:
        try:
            # `.jsonl` files are parsed line by line; `.json` as one document.
            eval_df = pd.read_json(eval_upload, lines=eval_upload.name.endswith('.jsonl'))
            st.success(f"✅ Loaded **{len(eval_df):,}** samples for evaluation")
            st.dataframe(eval_df.head(5), use_container_width=True)
            st.session_state['eval_data'] = eval_df
        except Exception as e:
            st.error(f"Error loading evaluation data: {e}")

    st.markdown("---")

    # ── Demo Charts ── (static placeholder values, not computed results)
    st.markdown("### 📈 Evaluation Results")

    col1, col2 = st.columns(2)
    with col1:
        fig = go.Figure(data=[go.Pie(
            values=[72, 18, 10],
            labels=['Fine-tuned Wins', 'Base Model Wins', 'Ties'],
            hole=0.6,
            marker_colors=['#6366f1', '#ef4444', '#94a3b8']
        )])
        fig.update_layout(
            title="Win Rate Distribution",
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font_color='#e2e8f0',
            showlegend=True
        )
        st.plotly_chart(fig, use_container_width=True)

    with col2:
        fig = go.Figure(data=[
            go.Bar(name='Base Model', x=['Accuracy', 'Helpfulness', 'Clarity', 'Relevance'], y=[6.2, 5.8, 6.5, 6.0], marker_color='#ef4444'),
            go.Bar(name='Fine-tuned', x=['Accuracy', 'Helpfulness', 'Clarity', 'Relevance'], y=[7.8, 8.1, 7.5, 8.2], marker_color='#6366f1')
        ])
        fig.update_layout(
            title="Score Comparison by Category",
            barmode='group',
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font_color='#e2e8f0',
            yaxis_title="Score (1-10)"
        )
        st.plotly_chart(fig, use_container_width=True)

    # Summary metrics
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Win Rate", "72%", "+22%")
    with col2:
        st.metric("Base Avg Score", "6.4/10")
    with col3:
        st.metric("Fine-tuned Avg", "7.8/10", "+1.4")
    with col4:
        st.metric("Comparisons", "50")

    st.markdown("---")

    # Run evaluation
    col1, col2, col3 = st.columns([1, 2, 1])
    with col2:
        if st.button("🏃 Run Full Evaluation", type="primary", use_container_width=True):
            # Only the key for the selected provider counts; a key left over
            # for a different provider must not satisfy this check.
            has_key = bool(os.environ.get(key_env_var))
            # eval_data may be a DataFrame, so compare against None rather
            # than relying on truthiness.
            has_eval_target = (
                bool(st.session_state.model_path)
                or st.session_state.get('eval_data') is not None
            )
            if not has_key:
                st.error("❌ Please provide an API key for your selected judge provider.")
            elif not has_eval_target:
                st.error("❌ Please either train a model, upload fine-tuned results, or upload evaluation data.")
            elif not judge_model or not judge_model.strip():
                # Only reachable for the custom endpoint, whose model name is free text.
                st.error("❌ Please enter a model name for the judge.")
            else:
                st.info(f"🏃 Starting evaluation with **{judge_model}** as judge...")
                st.warning("⏳ Full evaluation pipeline integration coming soon. Demo results shown above.")
|
| 1388 |
+
|
| 1389 |
+
|
| 1390 |
+
# ============================================================================
|
| 1391 |
+
# PAGE: DEPLOYMENT
|
| 1392 |
+
# ============================================================================
|
| 1393 |
+
def render_deploy():
    """Render the Deployment page.

    Lists the sub-directories of ``./output/models`` as selectable models,
    shows the shell command for starting the local FastAPI server, offers a
    HuggingFace Hub form, and displays static API documentation.
    """
    import shlex  # local import: only needed to quote the displayed command

    st.markdown('<p class="gradient-header">🌐 Model Deployment</p>', unsafe_allow_html=True)

    # Model selection
    st.markdown("### 📦 Select Model")

    models_dir = Path("./output/models")
    # is_dir() rather than exists(): iterdir() raises if the path is a file.
    if models_dir.is_dir():
        # Sorted so the list (and the default selection) is stable across reruns.
        models = sorted(d.name for d in models_dir.iterdir() if d.is_dir())
        if models:
            selected_model = st.selectbox("Trained Models", models)
            model_path = models_dir / selected_model
            st.info(f"📂 Model path: `{model_path}`")
        else:
            st.warning("No trained models found.")
            selected_model = None
    else:
        st.warning("Models directory not found.")
        selected_model = None

    st.markdown("---")

    # Deployment options
    st.markdown("### 🚀 Deployment Options")

    col1, col2 = st.columns(2)

    with col1:
        st.markdown("""
<div class="info-card">
<h4>🖥️ Local FastAPI Server</h4>
<p>Deploy as a REST API on your local machine.</p>
</div>
""", unsafe_allow_html=True)

        port = st.number_input("Port", value=8000, min_value=1000, max_value=65535)

        if st.button("🚀 Start Server", disabled=not selected_model):
            # The command is meant to be pasted into a shell, so quote the
            # path: directory names may contain spaces or shell metacharacters.
            model_arg = shlex.quote(f"./output/models/{selected_model}")
            st.code(f"python scripts/deploy.py --model {model_arg} --port {port}")
            st.info("Run the command above in your terminal to start the server.")

    with col2:
        st.markdown("""
<div class="info-card">
<h4>☁️ HuggingFace Hub</h4>
<p>Push your model to HuggingFace for sharing.</p>
</div>
""", unsafe_allow_html=True)

        hf_token = st.text_input("HuggingFace Token", type="password")
        repo_name = st.text_input("Repository Name", value=f"my-finetuned-{selected_model}" if selected_model else "")

        if st.button("☁️ Push to Hub", disabled=not selected_model or not hf_token):
            # NOTE: no upload is performed here; only this status message is shown.
            st.info("Pushing to HuggingFace Hub...")

    st.markdown("---")

    # API documentation
    st.markdown("### 📚 API Documentation")

    st.markdown("""
Once deployed, your API will have these endpoints:

| Endpoint | Method | Description |
|----------|--------|-------------|
| `/` | GET | API info |
| `/health` | GET | Health check |
| `/generate` | POST | Generate text |
| `/generate/batch` | POST | Batch generation |
""")

    with st.expander("📝 Example Request"):
        st.code("""
import requests

response = requests.post("http://localhost:8000/generate", json={
"prompt": "What are the symptoms of the common cold?",
"max_tokens": 256,
"temperature": 0.7
})
print(response.json()["generated_text"])
""", language="python")
|
| 1475 |
+
|
| 1476 |
+
|
| 1477 |
+
# ============================================================================
|
| 1478 |
+
# MAIN ROUTER
|
| 1479 |
+
# ============================================================================
|
| 1480 |
+
def main():
    """Route to the renderer for the page named in ``st.session_state.current_page``.

    Any page name that matches no known route falls back to the home page.
    """
    page = st.session_state.current_page

    # (page name, renderer) pairs, checked in order with ==.
    routes = (
        ('home', render_home),
        ('data', render_data_upload),
        ('process', render_processing),
        ('training', render_training),
        ('evaluation', render_evaluation),
        ('deploy', render_deploy),
    )
    for route_name, renderer in routes:
        if page == route_name:
            renderer()
            return

    # Unknown page: show the home page.
    render_home()
|
| 1497 |
+
|
| 1498 |
+
|
| 1499 |
+
if __name__ == "__main__":
|
| 1500 |
+
main()
|
configs/default_config.yaml
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Auto-FineTune-Ops Default Configuration
|
| 2 |
+
# ========================================
|
| 3 |
+
|
| 4 |
+
# Model Configuration
|
| 5 |
+
model:
|
| 6 |
+
base_model: "unsloth/llama-3-8b-bnb-4bit"
|
| 7 |
+
max_seq_length: 2048
|
| 8 |
+
load_in_4bit: true
|
| 9 |
+
dtype: null # Auto-detect
|
| 10 |
+
|
| 11 |
+
# Data Processing
|
| 12 |
+
data:
|
| 13 |
+
min_instruction_length: 10
|
| 14 |
+
max_instruction_length: 2048
|
| 15 |
+
min_response_length: 20
|
| 16 |
+
max_response_length: 4096
|
| 17 |
+
remove_duplicates: true
|
| 18 |
+
quality_threshold: 0.7
|
| 19 |
+
|
| 20 |
+
# Advanced Preprocessing Pipeline
|
| 21 |
+
preprocessing:
|
| 22 |
+
text_cleaning:
|
| 23 |
+
remove_html: true
|
| 24 |
+
remove_urls: true
|
| 25 |
+
remove_emojis: false
|
| 26 |
+
normalize_whitespace: true
|
| 27 |
+
lowercase: false
|
| 28 |
+
remove_special_chars: false
|
| 29 |
+
strip_extra_linebreaks: true
|
| 30 |
+
|
| 31 |
+
tokenization:
|
| 32 |
+
tokenizer_name: "tiktoken"
|
| 33 |
+
tiktoken_encoding: "cl100k_base"
|
| 34 |
+
max_total_tokens: 2048
|
| 35 |
+
truncate_long: false
|
| 36 |
+
split_long: false
|
| 37 |
+
split_overlap: 50
|
| 38 |
+
|
| 39 |
+
system_prompt:
|
| 40 |
+
prompt: "You are a helpful AI assistant."
|
| 41 |
+
prepend_to_all: true
|
| 42 |
+
|
| 43 |
+
balancing:
|
| 44 |
+
enabled: false
|
| 45 |
+
label_column: ""
|
| 46 |
+
strategy: "none"
|
| 47 |
+
|
| 48 |
+
quality_filters:
|
| 49 |
+
min_word_count: 3
|
| 50 |
+
max_word_count: 0
|
| 51 |
+
profanity_filter: false
|
| 52 |
+
language_filter: false
|
| 53 |
+
allowed_languages: ["en"]
|
| 54 |
+
remove_low_quality: true
|
| 55 |
+
min_quality_length: 20
|
| 56 |
+
|
| 57 |
+
deduplication:
|
| 58 |
+
remove_exact: true
|
| 59 |
+
remove_semantic: false
|
| 60 |
+
semantic_threshold: 0.90
|
| 61 |
+
|
| 62 |
+
split:
|
| 63 |
+
enabled: true
|
| 64 |
+
train_ratio: 0.9
|
| 65 |
+
random_seed: 42
|
| 66 |
+
shuffle: true
|
| 67 |
+
|
| 68 |
+
output_format:
|
| 69 |
+
format_type: "openai_chat"
|
| 70 |
+
|
| 71 |
+
pii_filter:
|
| 72 |
+
filter_emails: true
|
| 73 |
+
filter_phones: true
|
| 74 |
+
filter_id_numbers: true
|
| 75 |
+
filter_api_keys: true
|
| 76 |
+
filter_addresses: false
|
| 77 |
+
mask_char: "[REDACTED]"
|
| 78 |
+
|
| 79 |
+
augmentation:
|
| 80 |
+
enabled: false
|
| 81 |
+
paraphrase: false
|
| 82 |
+
generate_variations: false
|
| 83 |
+
back_translate: false
|
| 84 |
+
tone_rewrite: false
|
| 85 |
+
augmentation_factor: 1
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Training Hyperparameters (Auto-configured based on dataset size)
|
| 89 |
+
training:
|
| 90 |
+
# Small datasets (<1K samples)
|
| 91 |
+
small:
|
| 92 |
+
lora_rank: 8
|
| 93 |
+
lora_alpha: 16
|
| 94 |
+
learning_rate: 2.0e-4
|
| 95 |
+
num_epochs: 5
|
| 96 |
+
batch_size: 4
|
| 97 |
+
gradient_accumulation_steps: 4
|
| 98 |
+
|
| 99 |
+
# Medium datasets (1K-10K samples)
|
| 100 |
+
medium:
|
| 101 |
+
lora_rank: 16
|
| 102 |
+
lora_alpha: 32
|
| 103 |
+
learning_rate: 1.0e-4
|
| 104 |
+
num_epochs: 3
|
| 105 |
+
batch_size: 8
|
| 106 |
+
gradient_accumulation_steps: 2
|
| 107 |
+
|
| 108 |
+
# Large datasets (10K+ samples)
|
| 109 |
+
large:
|
| 110 |
+
lora_rank: 32
|
| 111 |
+
lora_alpha: 64
|
| 112 |
+
learning_rate: 5.0e-5
|
| 113 |
+
num_epochs: 2
|
| 114 |
+
batch_size: 16
|
| 115 |
+
gradient_accumulation_steps: 1
|
| 116 |
+
|
| 117 |
+
# Common settings
|
| 118 |
+
common:
|
| 119 |
+
warmup_ratio: 0.03
|
| 120 |
+
weight_decay: 0.01
|
| 121 |
+
optimizer: "adamw_8bit"
|
| 122 |
+
lr_scheduler: "cosine"
|
| 123 |
+
gradient_checkpointing: true
|
| 124 |
+
max_grad_norm: 1.0
|
| 125 |
+
|
| 126 |
+
# LoRA Configuration
|
| 127 |
+
lora:
|
| 128 |
+
target_modules:
|
| 129 |
+
- "q_proj"
|
| 130 |
+
- "k_proj"
|
| 131 |
+
- "v_proj"
|
| 132 |
+
- "o_proj"
|
| 133 |
+
- "gate_proj"
|
| 134 |
+
- "up_proj"
|
| 135 |
+
- "down_proj"
|
| 136 |
+
lora_dropout: 0.0
|
| 137 |
+
bias: "none"
|
| 138 |
+
use_rslora: true
|
| 139 |
+
|
| 140 |
+
# Evaluation (TheJudge)
|
| 141 |
+
evaluation:
|
| 142 |
+
judge_model: "gpt-4o" # Options: gpt-4o, claude-3-5-sonnet-20241022
|
| 143 |
+
num_test_samples: 50
|
| 144 |
+
temperature: 0.7
|
| 145 |
+
max_tokens: 512
|
| 146 |
+
|
| 147 |
+
# Deployment
|
| 148 |
+
deployment:
|
| 149 |
+
host: "0.0.0.0"
|
| 150 |
+
port: 8000
|
| 151 |
+
max_batch_size: 8
|
| 152 |
+
inference_max_tokens: 1024
|
| 153 |
+
|
| 154 |
+
# Output Paths
|
| 155 |
+
output:
|
| 156 |
+
base_dir: "./output"
|
| 157 |
+
models_dir: "./output/models"
|
| 158 |
+
logs_dir: "./output/logs"
|
| 159 |
+
reports_dir: "./output/reports"
|
| 160 |
+
data_dir: "./output/processed_data"
|
main.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Auto-FineTune-Ops: The Boss Orchestrator
|
| 3 |
+
==========================================
|
| 4 |
+
One-click autonomous ML fine-tuning pipeline.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python main.py --data ./data.csv --goal "medical_assistant"
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import yaml
|
| 13 |
+
import argparse
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
from typing import Optional, Dict, Any
|
| 17 |
+
|
| 18 |
+
from rich.console import Console
|
| 19 |
+
from rich.panel import Panel
|
| 20 |
+
from rich.progress import Progress, SpinnerColumn, TextColumn
|
| 21 |
+
from rich.markdown import Markdown
|
| 22 |
+
|
| 23 |
+
# Add project root to path
|
| 24 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 25 |
+
|
| 26 |
+
from agents.data_architect import DataArchitectAgent, CleaningConfig
|
| 27 |
+
from agents.training_pilot import TrainingPilot
|
| 28 |
+
from agents.the_judge import TheJudge, JudgeModel
|
| 29 |
+
|
| 30 |
+
console = Console()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class AutoFineTuneOps:
|
| 34 |
+
"""
|
| 35 |
+
The Boss Orchestrator - Runs the complete end-to-end fine-tuning pipeline.
|
| 36 |
+
|
| 37 |
+
Pipeline stages:
|
| 38 |
+
1. Data Preparation (DataArchitectAgent)
|
| 39 |
+
2. Fine-Tuning (TrainingPilot)
|
| 40 |
+
3. Evaluation (TheJudge)
|
| 41 |
+
4. Deployment Ready
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
config_path: Optional[str] = None,
|
| 47 |
+
output_dir: str = "./output"
|
| 48 |
+
):
|
| 49 |
+
"""
|
| 50 |
+
Initialize the orchestrator.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
config_path: Path to configuration YAML
|
| 54 |
+
output_dir: Base output directory
|
| 55 |
+
"""
|
| 56 |
+
self.config = self._load_config(config_path)
|
| 57 |
+
self.output_dir = Path(output_dir)
|
| 58 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 59 |
+
|
| 60 |
+
# Create subdirectories
|
| 61 |
+
(self.output_dir / "processed_data").mkdir(exist_ok=True)
|
| 62 |
+
(self.output_dir / "models").mkdir(exist_ok=True)
|
| 63 |
+
(self.output_dir / "logs").mkdir(exist_ok=True)
|
| 64 |
+
(self.output_dir / "reports").mkdir(exist_ok=True)
|
| 65 |
+
|
| 66 |
+
# Initialize agents
|
| 67 |
+
self.data_agent = None
|
| 68 |
+
self.training_agent = None
|
| 69 |
+
self.judge_agent = None
|
| 70 |
+
|
| 71 |
+
# Pipeline state
|
| 72 |
+
self.processed_data_path = None
|
| 73 |
+
self.model_path = None
|
| 74 |
+
self.evaluation_result = None
|
| 75 |
+
|
| 76 |
+
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
|
| 77 |
+
"""Load configuration from YAML file."""
|
| 78 |
+
default_config_path = Path(__file__).parent / "configs" / "default_config.yaml"
|
| 79 |
+
|
| 80 |
+
if config_path and Path(config_path).exists():
|
| 81 |
+
with open(config_path, 'r') as f:
|
| 82 |
+
return yaml.safe_load(f)
|
| 83 |
+
elif default_config_path.exists():
|
| 84 |
+
with open(default_config_path, 'r') as f:
|
| 85 |
+
return yaml.safe_load(f)
|
| 86 |
+
return {}
|
| 87 |
+
|
| 88 |
+
def _print_header(self):
    """Print the main header: a boxed banner rendered inside a rich Panel."""
    # The banner is a fixed multi-line string; it is passed to Panel as-is.
    header = """
╔═══════════════════════════════════════════════════════════════╗
║ ║
║ 🤖 AUTO-FINETUNE-OPS: AUTONOMOUS ML PIPELINE 🤖 ║
║ ║
║ "One-Click Fine-Tuning That Replaces Senior Engineers" ║
║ ║
╚═══════════════════════════════════════════════════════════════╝
"""
    console.print(Panel(header, style="bold magenta"))
|
| 100 |
+
|
| 101 |
+
def _print_stage(self, stage: int, name: str, description: str):
    """Print a stage banner: a rule, the stage number and name, the description, a rule."""
    rule = '=' * 60
    banner_lines = (
        f"\n[bold cyan]{rule}[/]",
        f"[bold cyan]STAGE {stage}: {name}[/]",
        f"[dim]{description}[/]",
        f"[bold cyan]{rule}[/]\n",
    )
    # One print call per line, exactly as many console writes as before.
    for line in banner_lines:
        console.print(line)
|
| 107 |
+
|
| 108 |
+
def run(
    self,
    data_path: str,
    goal: str,
    base_model: Optional[str] = None,
    skip_training: bool = False,
    skip_evaluation: bool = False,
    judge_model: str = "gpt-4o",
    num_eval_samples: int = 50
) -> Dict[str, Any]:
    """
    Run the complete fine-tuning pipeline.

    Stages run in order: data preparation, fine-tuning, evaluation and a
    final summary. An exception raised by any stage is caught, printed
    together with its traceback, stored under ``results["error"]`` and the
    partial results are returned instead of raising.

    Args:
        data_path: Path to input dataset (CSV/JSON)
        goal: Training goal/purpose
        base_model: Override base model
        skip_training: Skip training stage (use existing model)
        skip_evaluation: Skip evaluation stage
        judge_model: LLM to use as judge
        num_eval_samples: Number of samples for evaluation

    Returns:
        Dict with ``run_name``, ``goal``, a ``stages`` mapping holding one
        status entry per stage, and ``error`` when the pipeline failed
    """
    import re  # local import: only needed to build a filename-safe run name

    self._print_header()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # The run name is embedded in output file names, so anything that is not
    # filename-safe (spaces, path separators, ...) is replaced by underscores.
    safe_goal = re.sub(r"[^\w.-]+", "_", str(goal)).strip("_") or "run"
    run_name = f"{safe_goal}_{timestamp}"

    # Resolve model settings once so the banner, training and evaluation agree.
    config = self.config or {}
    model_config = config.get('model') or {}
    base_model = base_model or model_config.get('base_model', 'unsloth/llama-3-8b-bnb-4bit')
    max_seq_length = model_config.get('max_seq_length', 2048)

    console.print(f"[bold]Run Name:[/] {run_name}")
    console.print(f"[bold]Input Data:[/] {data_path}")
    console.print(f"[bold]Goal:[/] {goal}")
    console.print(f"[bold]Base Model:[/] {base_model}")

    results = {
        "run_name": run_name,
        "goal": goal,
        "stages": {}
    }

    try:
        # ───────────────────────────────────────────────────────────
        # STAGE 1: DATA PREPARATION
        # ───────────────────────────────────────────────────────────
        self._print_stage(
            1,
            "DATA PREPARATION",
            "Analyzing, cleaning, and formatting dataset for training"
        )

        # Initialize data agent with config
        data_config = config.get('data') or {}
        cleaning_config = CleaningConfig(
            min_instruction_length=data_config.get('min_instruction_length', 10),
            max_instruction_length=data_config.get('max_instruction_length', 2048),
            min_response_length=data_config.get('min_response_length', 20),
            max_response_length=data_config.get('max_response_length', 4096),
            remove_duplicates=data_config.get('remove_duplicates', True),
            quality_threshold=data_config.get('quality_threshold', 0.7)
        )

        self.data_agent = DataArchitectAgent(config=cleaning_config)

        # Process data
        output_jsonl = self.output_dir / "processed_data" / f"{run_name}_training.jsonl"
        self.processed_data_path, data_analysis = self.data_agent.process(
            input_path=data_path,
            output_path=str(output_jsonl),
            goal=goal
        )

        results["stages"]["data_preparation"] = {
            "status": "success",
            "output_path": self.processed_data_path,
            "total_samples": data_analysis.valid_rows,
            "quality_score": data_analysis.quality_score
        }

        # ───────────────────────────────────────────────────────────
        # STAGE 2: FINE-TUNING
        # ───────────────────────────────────────────────────────────
        if not skip_training:
            self._print_stage(
                2,
                "FINE-TUNING",
                "Auto-configuring hyperparameters and training with Unsloth"
            )

            self.training_agent = TrainingPilot(
                base_model=base_model,
                max_seq_length=max_seq_length,
                output_dir=str(self.output_dir / "models"),
                config_path=None
            )

            # Run training
            training_result = self.training_agent.run(
                data_path=self.processed_data_path,
                output_name=run_name
            )

            self.model_path = training_result.model_path

            results["stages"]["training"] = {
                "status": "success",
                "model_path": self.model_path,
                "training_time": training_result.training_time,
                "final_loss": training_result.final_loss,
                "hyperparams": training_result.hyperparams.to_dict()
            }
        else:
            console.print("[yellow]⏭️ Skipping training stage[/]")
            results["stages"]["training"] = {"status": "skipped"}

        # ───────────────────────────────────────────────────────────
        # STAGE 3: EVALUATION
        # ───────────────────────────────────────────────────────────
        if skip_evaluation:
            console.print("[yellow]⏭️ Skipping evaluation stage[/]")
            results["stages"]["evaluation"] = {"status": "skipped"}
        elif not self.model_path:
            # Without a fine-tuned model there is nothing to compare; record
            # that explicitly so the results never lack an evaluation entry.
            console.print("[yellow]⚠️ No fine-tuned model available. Skipping evaluation.[/]")
            results["stages"]["evaluation"] = {
                "status": "skipped",
                "reason": "No fine-tuned model available"
            }
        else:
            self._print_stage(
                3,
                "EVALUATION",
                "Running Model Arena with LLM-as-Judge"
            )

            eval_config = config.get('evaluation') or {}
            judge_model_str = judge_model or eval_config.get('judge_model', 'gpt-4o')

            # Decide the provider once, case-insensitively, so the API-key
            # check and the enum choice below can never disagree.
            uses_claude = "claude" in judge_model_str.lower()
            required_key = "ANTHROPIC_API_KEY" if uses_claude else "OPENAI_API_KEY"

            if not os.getenv(required_key):
                console.print(f"[yellow]⚠️ {required_key} not set. Skipping evaluation.[/]")
                results["stages"]["evaluation"] = {
                    "status": "skipped",
                    "reason": "No API key"
                }
            else:
                judge_enum = JudgeModel.CLAUDE_35_SONNET if uses_claude else JudgeModel.GPT4O

                self.judge_agent = TheJudge(
                    judge_model=judge_enum,
                    temperature=eval_config.get('temperature', 0.2),
                    max_tokens=eval_config.get('max_tokens', 1024)
                )

                # Load models for evaluation
                console.print("[blue]Loading models for evaluation...[/]")

                try:
                    from unsloth import FastLanguageModel

                    # Load base model with the same sequence length used for training
                    base_model_obj, base_tokenizer = FastLanguageModel.from_pretrained(
                        model_name=base_model,
                        max_seq_length=max_seq_length,
                        load_in_4bit=True,
                    )

                    # Load fine-tuned model
                    ft_model, ft_tokenizer = FastLanguageModel.from_pretrained(
                        model_name=self.model_path,
                        max_seq_length=max_seq_length,
                        load_in_4bit=True,
                    )

                    # Run evaluation
                    self.evaluation_result = self.judge_agent.run_with_test_data(
                        base_model=base_model_obj,
                        finetuned_model=ft_model,
                        tokenizer=base_tokenizer,
                        test_data_path=self.processed_data_path,
                        num_samples=num_eval_samples,
                        finetuned_tokenizer=ft_tokenizer
                    )

                    # Generate report
                    report_path = self.output_dir / "reports" / f"{run_name}_evaluation.json"
                    self.judge_agent.generate_report(
                        self.evaluation_result,
                        str(report_path)
                    )

                    results["stages"]["evaluation"] = {
                        "status": "success",
                        "win_rate": self.evaluation_result.win_rate,
                        "base_avg_score": self.evaluation_result.base_model_avg_score,
                        "finetuned_avg_score": self.evaluation_result.finetuned_avg_score,
                        "report_path": str(report_path)
                    }

                except ImportError:
                    console.print("[yellow]⚠️ Unsloth not available for evaluation. Skipping.[/]")
                    results["stages"]["evaluation"] = {
                        "status": "skipped",
                        "reason": "Unsloth not available"
                    }

        # ───────────────────────────────────────────────────────────
        # STAGE 4: SUMMARY
        # ───────────────────────────────────────────────────────────
        self._print_stage(
            4,
            "PIPELINE COMPLETE",
            "Summary of the autonomous fine-tuning run"
        )

        self._print_summary(results)

        # Save results
        results_path = self.output_dir / "logs" / f"{run_name}_results.yaml"
        with open(results_path, 'w', encoding='utf-8') as f:
            yaml.dump(results, f, default_flow_style=False)

        console.print(f"\n[green]✓ Results saved to: {results_path}[/]")

        return results

    except Exception as e:
        # Top-level boundary: report the failure and hand back whatever
        # stages completed rather than propagating the exception.
        console.print(f"\n[bold red]❌ Pipeline failed: {str(e)}[/]")
        import traceback
        traceback.print_exc()
        results["error"] = str(e)
        return results
|
| 350 |
+
|
| 351 |
+
def _print_summary(self, results: Dict[str, Any]):
    """
    Print a three-column table summarising each pipeline stage.

    A row is added only for stages whose status is "success" (or, for
    training and evaluation, "skipped"). When a model path is set, the
    path and a deploy command are printed after the table.
    """
    from rich.table import Table

    stages = results["stages"]

    summary = Table(title="Pipeline Summary", show_header=True)
    for heading, colour in (("Stage", "cyan"), ("Status", "green"), ("Details", "dim")):
        summary.add_column(heading, style=colour)

    # Data preparation
    data_info = stages.get("data_preparation", {})
    if data_info.get("status") == "success":
        details = f"{data_info.get('total_samples', 0):,} samples (Quality: {data_info.get('quality_score', 0):.1%})"
        summary.add_row("Data Preparation", "✅ Success", details)

    # Training
    training_info = stages.get("training", {})
    training_status = training_info.get("status")
    if training_status == "success":
        summary.add_row(
            "Fine-Tuning",
            "✅ Success",
            f"Loss: {training_info.get('final_loss', 0):.4f}"
        )
    elif training_status == "skipped":
        summary.add_row("Fine-Tuning", "⏭️ Skipped", "")

    # Evaluation
    evaluation_info = stages.get("evaluation", {})
    evaluation_status = evaluation_info.get("status")
    if evaluation_status == "success":
        summary.add_row(
            "Evaluation",
            "✅ Success",
            f"Win Rate: {evaluation_info.get('win_rate', 0):.1%}"
        )
    elif evaluation_status == "skipped":
        summary.add_row("Evaluation", "⏭️ Skipped", evaluation_info.get("reason", ""))

    console.print(summary)

    # Nothing more to report unless a fine-tuned model exists
    if not self.model_path:
        return

    console.print(f"\n[bold green]📦 Fine-tuned model saved to:[/]")
    console.print(f" {self.model_path}")
    console.print(f"\n[bold]To deploy, run:[/]")
    console.print(f" [cyan]python scripts/deploy.py --model {self.model_path}[/]")
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def main():
    """
    CLI entry point.

    Parses the command-line arguments, runs the pipeline and exits with
    status 1 when the pipeline reports an error.
    """
    parser = argparse.ArgumentParser(
        description="Auto-FineTune-Ops: One-click autonomous ML fine-tuning pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python main.py --data ./data.csv --goal medical_assistant
python main.py --data ./qa_pairs.json --goal customer_support --model unsloth/llama-3-8b-bnb-4bit
python main.py --data ./dataset.jsonl --goal code_assistant --skip-eval
"""
    )

    parser.add_argument(
        "--data",
        required=True,
        help="Path to input dataset (CSV, JSON, or JSONL)"
    )
    parser.add_argument(
        "--goal",
        required=True,
        help="Training goal (e.g., medical_assistant, customer_support)"
    )
    parser.add_argument(
        "--model",
        default=None,
        help="Base model to fine-tune (default: unsloth/llama-3-8b-bnb-4bit)"
    )
    parser.add_argument(
        "--config",
        default=None,
        help="Path to configuration YAML file"
    )
    parser.add_argument(
        "--output",
        default="./output",
        help="Output directory (default: ./output)"
    )
    parser.add_argument(
        "--skip-training",
        action="store_true",
        help="Skip training stage"
    )
    parser.add_argument(
        "--skip-eval",
        action="store_true",
        help="Skip evaluation stage"
    )
    parser.add_argument(
        "--judge",
        choices=["gpt-4o", "claude-3-5-sonnet"],
        default="gpt-4o",
        help="Judge LLM for evaluation (default: gpt-4o)"
    )
    parser.add_argument(
        "--eval-samples",
        type=int,
        default=50,
        help="Number of samples for evaluation (default: 50)"
    )

    args = parser.parse_args()

    # Run pipeline
    orchestrator = AutoFineTuneOps(
        config_path=args.config,
        output_dir=args.output
    )

    results = orchestrator.run(
        data_path=args.data,
        goal=args.goal,
        base_model=args.model,
        skip_training=args.skip_training,
        skip_evaluation=args.skip_eval,
        judge_model=args.judge,
        num_eval_samples=args.eval_samples
    )

    # run() reports failures through an "error" key instead of raising, so
    # translate that into a non-zero exit status for shells and CI jobs.
    if results and results.get("error"):
        raise SystemExit(1)


if __name__ == "__main__":
    main()
|
preprocessing/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Preprocessing Pipeline for LLM Fine-Tuning
|
| 3 |
+
============================================
|
| 4 |
+
Modular preprocessing stages for cleaning, filtering,
|
| 5 |
+
formatting, and exporting datasets.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from preprocessing.pipeline import PreprocessingPipeline, PreprocessingConfig
|
preprocessing/augmentation.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Augmentation Module (Optional)
|
| 3 |
+
================================
|
| 4 |
+
Lightweight synthetic data expansion stubs.
|
| 5 |
+
These are pure-Python approximations. For production quality,
|
| 6 |
+
integrate with an LLM API or NLP library.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import random
|
| 10 |
+
import re
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import List
|
| 13 |
+
import pandas as pd
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class AugmentationConfig:
    """
    Switches controlling the optional augmentation step.

    ``enabled`` is the master switch; the four strategy flags choose which
    rewrite functions may be applied. Everything is off by default.
    """
    # Master switch: when False the dataset is returned untouched
    enabled: bool = False
    # Allow synonym-based paraphrasing
    paraphrase: bool = False
    # Allow minor surface variations
    generate_variations: bool = False
    # Allow the back-translation stub
    back_translate: bool = False
    # Allow the tone-rewriting stub
    tone_rewrite: bool = False
    # Number of extra augmented copies produced per sample
    augmentation_factor: int = 1
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Synonym map for lightweight paraphrasing
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
_SYNONYMS = {
    'explain': ['describe', 'elaborate on', 'clarify', 'break down'],
    'create': ['generate', 'produce', 'make', 'build'],
    'write': ['compose', 'draft', 'author', 'pen'],
    'list': ['enumerate', 'outline', 'itemize', 'catalog'],
    'help': ['assist', 'aid', 'support', 'guide'],
    'show': ['demonstrate', 'display', 'present', 'illustrate'],
    'tell': ['inform', 'describe', 'narrate', 'share'],
    'give': ['provide', 'supply', 'offer', 'deliver'],
    'find': ['locate', 'discover', 'identify', 'search for'],
    'use': ['utilize', 'employ', 'apply', 'leverage'],
    'what': ['which', 'what exactly'],
    'how': ['in what way', 'by what method'],
    'important': ['crucial', 'essential', 'significant', 'vital'],
    'good': ['excellent', 'great', 'effective', 'beneficial'],
    'bad': ['poor', 'negative', 'harmful', 'detrimental'],
    'big': ['large', 'significant', 'substantial', 'major'],
    'small': ['minor', 'slight', 'modest', 'minimal'],
}


def paraphrase_instruction(text: str) -> str:
    """
    Simple synonym-based paraphrasing.

    Replaces one randomly chosen word that has an entry in ``_SYNONYMS``
    with a randomly chosen synonym. The rest of the text, including line
    breaks and any punctuation attached to the replaced word, is left
    exactly as it was.

    Args:
        text: Instruction to paraphrase

    Returns:
        The paraphrased text, or ``text`` unchanged when it is not a
        string, is shorter than 5 characters once stripped, or contains
        no replaceable word
    """
    if not isinstance(text, str) or len(text.strip()) < 5:
        return text

    punctuation = '.,!?;:'

    # Work on match objects rather than text.split() so the replacement can
    # be spliced into the original string without flattening its whitespace.
    candidates = []
    for token in re.finditer(r'\S+', text):
        if token.group().strip(punctuation).lower() in _SYNONYMS:
            candidates.append(token)

    if not candidates:
        return text

    chosen = random.choice(candidates)
    word = chosen.group()
    core = word.strip(punctuation)
    replacement = random.choice(_SYNONYMS[core.lower()])

    # Preserve original casing
    if core[0].isupper():
        replacement = replacement.capitalize()

    # Preserve every leading and trailing punctuation character ("explain..."
    # keeps all three dots, not just the last one).
    leading = word[:len(word) - len(word.lstrip(punctuation))]
    trailing = word[len(word.rstrip(punctuation)):]

    return text[:chosen.start()] + leading + replacement + trailing + text[chosen.end():]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def generate_variation(text: str) -> str:
    """
    Apply one randomly chosen light-touch edit to the text.

    Candidate edits:
    - swap the trailing run of ``.``/``!``/``?`` for a random ``.``, ``!``,
      ``?`` or nothing
    - upper-case the first character
    - collapse whitespace runs to single spaces and trim both ends
    - append ' Please be detailed.' with 50% probability

    Non-string input, and text shorter than 5 characters once stripped,
    is returned unchanged.
    """
    if not isinstance(text, str) or len(text.strip()) < 5:
        return text

    def _swap_ending(t):
        return t.rstrip('.!?') + random.choice(['.', '!', '?', ''])

    def _capitalise_first(t):
        return t[0].upper() + t[1:] if len(t) > 1 else t

    def _squeeze_whitespace(t):
        return re.sub(r'\s+', ' ', t).strip()

    def _ask_for_detail(t):
        return t + ' Please be detailed.' if random.random() > 0.5 else t

    edits = [_swap_ending, _capitalise_first, _squeeze_whitespace, _ask_for_detail]
    chosen_edit = random.choice(edits)
    return chosen_edit(text)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def back_translate(text: str) -> str:
    """
    Placeholder for back-translation.

    A production version would round-trip the text through another
    language. This stub only delegates to ``paraphrase_instruction``.
    """
    paraphrased = paraphrase_instruction(text)
    return paraphrased
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def rewrite_tone(text: str, tone: str = "formal") -> str:
    """
    Stub for tone rewriting.

    Prepends a tone-specific prefix when the text starts with one of a
    small set of imperative verbs, lower-casing the first letter so the
    sentence still reads naturally.

    Args:
        text: Instruction to rewrite
        tone: One of "formal", "casual", "academic", "friendly"

    Returns:
        The prefixed text, or ``text`` unchanged when it is not a string,
        the tone is unknown, the text already starts with the prefix, or
        it does not start with a recognised action word
    """
    # Same tolerance for non-string cells as the other augmentation helpers
    if not isinstance(text, str):
        return text

    tone_prefixes = {
        'formal': 'Please ',
        'casual': 'Hey, can you ',
        'academic': 'Kindly provide a detailed analysis of ',
        'friendly': 'I would really appreciate if you could ',
    }

    prefix = tone_prefixes.get(tone, '')
    if not prefix:
        # Unknown tone: nothing to prepend
        return text

    # Leading whitespace must not end up between the prefix and the verb
    body = text.lstrip()

    # Don't double-prefix
    if body.lower().startswith(prefix.lower().strip()):
        return text

    # Simple approach: prepend tone prefix if the text starts with a verb-like word
    words = body.split()
    first_word = words[0].lower() if words else ''
    action_words = {'explain', 'describe', 'write', 'create', 'list', 'show', 'tell', 'give', 'find', 'help', 'make'}

    if first_word in action_words:
        return prefix + body[0].lower() + body[1:]

    return text
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def augment_dataset(
    df: pd.DataFrame,
    col: str,
    config: AugmentationConfig,
) -> pd.DataFrame:
    """
    Apply augmentation to create additional samples.

    For every row, ``config.augmentation_factor`` copies are made and the
    value in ``col`` of each copy is rewritten by one randomly chosen
    enabled method. All other columns are copied as they are.

    Args:
        df: Source dataset
        col: Name of the text column to rewrite
        config: Augmentation switches

    Returns:
        The original rows followed by the augmented rows with a fresh
        index, or ``df`` itself when augmentation is disabled, no method
        is enabled, or there is nothing to generate
    """
    if not config.enabled:
        return df

    methods = []
    if config.paraphrase:
        methods.append(paraphrase_instruction)
    if config.generate_variations:
        methods.append(generate_variation)
    if config.back_translate:
        methods.append(back_translate)
    if config.tone_rewrite:
        methods.append(lambda t: rewrite_tone(t, "formal"))

    if not methods:
        return df

    # Repeat each row positionally (row 0 x factor, row 1 x factor, ...).
    # Copying through iloc keeps each column's dtype, which building a frame
    # from iterrows() rows does not.
    positions = [
        position
        for position in range(len(df))
        for _ in range(config.augmentation_factor)
    ]
    if not positions:
        return df

    augmented = df.iloc[positions].copy()
    augmented[col] = [
        random.choice(methods)(str(value)) for value in augmented[col]
    ]
    return pd.concat([df, augmented], ignore_index=True)
|
preprocessing/dataset_balancing.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset Balancing Module
|
| 3 |
+
=========================
|
| 4 |
+
Class balancing for classification datasets via
|
| 5 |
+
oversampling / undersampling strategies.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Dict, Optional
|
| 10 |
+
import pandas as pd
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class BalancingConfig:
    """
    Settings for class balancing.

    Balancing is off by default; ``strategy`` names the resampling
    approach to apply to ``label_column``.
    """
    # Whether balancing should be applied at all
    enabled: bool = False
    # Column holding the class label
    label_column: str = ""
    # One of "none", "oversample", "undersample"
    strategy: str = "none"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def compute_label_distribution(
    df: pd.DataFrame,
    label_col: str,
) -> Dict[str, int]:
    """
    Count how often each label occurs in ``label_col``.

    Returns a dict of label_value -> count, or an empty dict when the
    column is not present in the frame.
    """
    if label_col in df.columns:
        frequencies = df[label_col].value_counts()
        return frequencies.to_dict()
    return {}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def oversample_minority(
    df: pd.DataFrame,
    label_col: str,
) -> pd.DataFrame:
    """
    Oversample minority classes to match the majority class count.

    Smaller classes are topped up by sampling their own rows with
    replacement (fixed seed, so the result is reproducible). The output
    is grouped by class and has a fresh index.

    Args:
        df: Dataset to balance
        label_col: Column holding the class label

    Returns:
        The balanced frame, or ``df`` unchanged when the column is
        missing or there are no labelled rows to balance
    """
    if label_col not in df.columns:
        return df

    counts = df[label_col].value_counts()
    if counts.empty:
        # No rows, or no usable labels: pd.concat([]) below would raise.
        return df

    max_count = counts.max()

    frames = []
    for label, count in counts.items():
        label_df = df[df[label_col] == label]
        frames.append(label_df)
        if count < max_count:
            # Resample with replacement to reach max_count
            frames.append(
                label_df.sample(n=int(max_count - count), replace=True, random_state=42)
            )

    return pd.concat(frames, ignore_index=True)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def undersample_majority(
    df: pd.DataFrame,
    label_col: str,
) -> pd.DataFrame:
    """
    Undersample majority classes to match the minority class count.

    Larger classes are cut down by sampling without replacement (fixed
    seed, so the result is reproducible). The output is grouped by class
    and has a fresh index.

    Args:
        df: Dataset to balance
        label_col: Column holding the class label

    Returns:
        The balanced frame, or ``df`` unchanged when the column is
        missing or there are no labelled rows to balance
    """
    if label_col not in df.columns:
        return df

    counts = df[label_col].value_counts()
    if counts.empty:
        # No rows, or no usable labels: pd.concat([]) below would raise.
        return df

    min_count = int(counts.min())

    frames = []
    for label in counts.index:
        label_df = df[df[label_col] == label]
        if len(label_df) > min_count:
            label_df = label_df.sample(n=min_count, random_state=42)
        frames.append(label_df)

    return pd.concat(frames, ignore_index=True)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def balance_dataset(
    df: pd.DataFrame,
    label_col: str,
    strategy: str = "none",
) -> pd.DataFrame:
    """
    Dispatch to the resampling routine named by ``strategy``.

    strategy: "oversample" or "undersample"; any other value (including
    the default "none") returns the frame untouched.
    """
    if strategy not in ("oversample", "undersample"):
        return df

    resample = oversample_minority if strategy == "oversample" else undersample_majority
    return resample(df, label_col)
|
preprocessing/deduplication.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Deduplication Module
|
| 3 |
+
======================
|
| 4 |
+
Exact and semantic (TF-IDF cosine similarity) deduplication.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class DeduplicationConfig:
    """
    Settings for the deduplication step.

    Exact deduplication is on by default; the TF-IDF based semantic pass
    is opt-in.
    """
    # Drop rows whose text is an exact repeat
    remove_exact: bool = True
    # Also drop near-duplicates found via cosine similarity
    remove_semantic: bool = False
    # Cosine similarity at or above which a row counts as a duplicate
    semantic_threshold: float = 0.90
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def remove_exact_duplicates(
    df: pd.DataFrame,
    col: str,
) -> pd.DataFrame:
    """Drop rows that repeat a value already seen in ``col`` and renumber the index."""
    unique_rows = df.drop_duplicates(subset=[col])
    return unique_rows.reset_index(drop=True)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def remove_semantic_duplicates(
    df: pd.DataFrame,
    col: str,
    threshold: float = 0.90,
) -> pd.DataFrame:
    """
    Drop near-duplicate rows based on TF-IDF cosine similarity.

    Rows are visited in order; a row is kept only if its similarity to
    every previously kept row is below ``threshold``. The frame is
    returned unchanged when it has fewer than two rows, when scikit-learn
    cannot be imported, or when vectorisation raises ``ValueError``.
    """
    if len(df) < 2:
        return df

    try:
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity
    except ImportError:
        # Best effort: without scikit-learn the semantic pass is a no-op
        return df

    documents = df[col].fillna('').astype(str).tolist()

    # Build TF-IDF matrix
    tfidf = TfidfVectorizer(max_features=5000, stop_words='english')
    try:
        vectors = tfidf.fit_transform(documents)
    except ValueError:
        # presumably raised when no usable terms remain — TODO confirm
        return df

    # The first row has nothing earlier to duplicate, so it is always kept
    kept_positions = [0]

    for position in range(1, len(documents)):
        similarities = cosine_similarity(
            vectors[position:position + 1],
            vectors[kept_positions],
        )
        if similarities.max() < threshold:
            kept_positions.append(position)

    survivors = df.iloc[kept_positions]
    return survivors.reset_index(drop=True)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def apply_deduplication(
    df: pd.DataFrame,
    col: str,
    config: DeduplicationConfig,
) -> pd.DataFrame:
    """
    Run the deduplication passes enabled in ``config``.

    The exact pass runs first; the semantic pass then works on its output.
    """
    deduplicated = df

    if config.remove_exact:
        deduplicated = remove_exact_duplicates(deduplicated, col)

    if config.remove_semantic:
        deduplicated = remove_semantic_duplicates(
            deduplicated, col, config.semantic_threshold
        )

    return deduplicated
|
preprocessing/output_formatter.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Output Formatter Module
|
| 3 |
+
=========================
|
| 4 |
+
Export datasets in multiple JSONL formats:
|
| 5 |
+
- OpenAI Chat JSONL
|
| 6 |
+
- Completion JSONL
|
| 7 |
+
- Classification JSONL
|
| 8 |
+
- Custom schema JSONL
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
from typing import List, Dict, Any, Optional
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class OutputFormatConfig:
    """Settings that select the JSONL export layout."""

    # Layout selector: "openai_chat", "completion", "classification" or "custom".
    format_type: str = "openai_chat"
    # Only read for the "custom" layout: maps output_key -> source_column,
    # e.g. {"text": "instruction", "label": "category"}.
    custom_schema: Dict[str, str] = field(default_factory=dict)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def format_openai_chat(
    df: pd.DataFrame,
    system_prompt: str,
    instruction_col: str,
    output_col: str,
    input_col: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """
    Format as OpenAI Chat JSONL.

    Each entry: {"messages": [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]}

    Args:
        df: Source rows.
        system_prompt: Text of the leading system message; the message is
            omitted entirely when this is empty.
        instruction_col: Column holding the user turn.
        output_col: Column holding the assistant turn.
        input_col: Optional column with extra context. When present and
            non-empty for a row it is appended to the user turn as
            "\\n\\nContext: <value>".

    Returns:
        One ``{"messages": [...]}`` dict per row, in row order.
    """
    # Resolve once: the column either exists for every row or for none.
    use_context = bool(input_col) and input_col in df.columns

    data = []
    for _, row in df.iterrows():
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        user_content = str(row[instruction_col])
        if use_context:
            raw = row.get(input_col, '')
            # Treat None / NaN / pd.NA as "no context" instead of letting
            # str() turn them into the words "None" or "<NA>".
            missing = pd.api.types.is_scalar(raw) and pd.isna(raw)
            context = '' if missing else str(raw)
            # The literal 'nan' check is kept for columns that were already
            # stringified upstream.
            if context and context != 'nan':
                user_content += f"\n\nContext: {context}"

        messages.append({"role": "user", "content": user_content})
        messages.append({"role": "assistant", "content": str(row[output_col])})

        data.append({"messages": messages})
    return data
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def format_completion(
    df: pd.DataFrame,
    instruction_col: str,
    output_col: str,
) -> List[Dict[str, Any]]:
    """Build prompt/completion records, one per row.

    Args:
        df: Source rows.
        instruction_col: Column rendered (via ``str``) into "prompt".
        output_col: Column rendered (via ``str``) into "completion".

    Returns:
        A list of ``{"prompt": ..., "completion": ...}`` dicts in row order.
    """
    return [
        {
            "prompt": str(record[instruction_col]),
            "completion": str(record[output_col]),
        }
        for _, record in df.iterrows()
    ]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def format_classification(
    df: pd.DataFrame,
    text_col: str,
    label_col: str,
) -> List[Dict[str, Any]]:
    """Build text/label records, one per row.

    Args:
        df: Source rows.
        text_col: Column rendered (via ``str``) into "text".
        label_col: Column rendered (via ``str``) into "label".

    Returns:
        A list of ``{"text": ..., "label": ...}`` dicts in row order.
    """
    return [
        {
            "text": str(record[text_col]),
            "label": str(record[label_col]),
        }
        for _, record in df.iterrows()
    ]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def format_custom(
    df: pd.DataFrame,
    schema: Dict[str, str],
) -> List[Dict[str, Any]]:
    """Build records whose keys and source columns come from ``schema``.

    Args:
        df: Source rows.
        schema: Maps each output key to the name of the column to read.

    Returns:
        One dict per row. A key whose source column is absent from ``df``
        is still emitted, with an empty string as its value.
    """

    def _render(record) -> Dict[str, Any]:
        # Absent columns degrade to "" rather than raising.
        return {
            out_key: str(record[src_col]) if src_col in df.columns else ""
            for out_key, src_col in schema.items()
        }

    return [_render(record) for _, record in df.iterrows()]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def export_jsonl(data: List[Dict[str, Any]], path: str) -> str:
    """Serialise ``data`` to ``path`` as JSON Lines (one object per line).

    Missing parent directories are created. Non-ASCII characters are written
    as-is in UTF-8 rather than escaped.

    Returns:
        The destination path as a string.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    with target.open('w', encoding='utf-8') as handle:
        handle.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in data
        )

    return str(target)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def generate_preview(data: List[Dict[str, Any]], n: int = 3) -> str:
    """Render the leading ``n`` entries of ``data`` as indented JSON text."""
    sample = data[:n]
    return json.dumps(sample, ensure_ascii=False, indent=2)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def format_dataset(
    df: pd.DataFrame,
    config: OutputFormatConfig,
    system_prompt: str = "",
    instruction_col: str = "",
    output_col: str = "",
    input_col: Optional[str] = None,
    label_col: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Dispatch to the formatter named by ``config.format_type``.

    "completion", "classification" and "custom" each have a dedicated
    formatter. "openai_chat" and any unrecognised value both produce the
    OpenAI chat layout.

    For "classification", the text column falls back to the first column of
    ``df`` (or "" if there are none) when ``instruction_col`` is empty, and
    the label column falls back to ``output_col`` when ``label_col`` is unset.
    """
    kind = config.format_type

    if kind == "completion":
        return format_completion(df, instruction_col, output_col)

    if kind == "classification":
        text_col = instruction_col
        if not text_col:
            text_col = list(df.columns)[0] if len(df.columns) > 0 else ""
        lbl_col = label_col or output_col
        return format_classification(df, text_col, lbl_col)

    if kind == "custom":
        return format_custom(df, config.custom_schema)

    # "openai_chat" plus every unknown format type end up here.
    return format_openai_chat(df, system_prompt, instruction_col, output_col, input_col)
|
preprocessing/pii_filter.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PII (Personally Identifiable Information) Filter Module
|
| 3 |
+
=========================================================
|
| 4 |
+
Regex-based detection and masking for emails, phone numbers,
|
| 5 |
+
CNIC/SSN-like patterns, API keys, and addresses.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import List, Dict, Tuple
|
| 11 |
+
import pandas as pd
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class PIIFilterConfig:
    """Per-category switches for PII masking; every category is off by default."""

    filter_emails: bool = False
    filter_phones: bool = False
    # CNIC / SSN patterns
    filter_id_numbers: bool = False
    filter_api_keys: bool = False
    filter_addresses: bool = False
    # Replacement text substituted for each match.
    mask_char: str = "[REDACTED]"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
# Detection + Masking patterns
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
# Email address. The TLD class is letters only: inside a character class a
# "|" is a literal pipe, not alternation, so it must not appear there.
_EMAIL_PATTERN = re.compile(
    r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b'
)

# Phone number: optional country code, optional parenthesised area code,
# groups separated by "-", "." or whitespace.
_PHONE_PATTERN = re.compile(
    r'(?:\+?\d{1,3}[-.\s]?)?\(?\d{2,4}\)?[-.\s]?\d{3,4}[-.\s]?\d{3,4}'
)

# SSN: 123-45-6789, CNIC: 12345-1234567-1
_ID_NUMBER_PATTERN = re.compile(
    r'\b\d{3}-\d{2}-\d{4}\b'        # US SSN
    r'|\b\d{5}-\d{7}-\d{1}\b'       # PK CNIC
    r'|\b\d{13}\b'                  # 13-digit ID
)

# Long hex or base64 strings that look like API keys / secrets
_API_KEY_PATTERN = re.compile(
    r'\b(?:sk|pk|api|key|secret|token)[_-]?[A-Za-z0-9]{20,}\b'
    r'|[A-Fa-f0-9]{32,}'
    r'|[A-Za-z0-9+/]{40,}={0,2}',
    re.IGNORECASE,
)

# Basic address patterns (US-style zip, PO Box, street numbers)
# NOTE(review): the zip-code alternative matches any standalone 5-digit number.
_ADDRESS_PATTERN = re.compile(
    r'\b\d{1,5}\s+\w+\s+(?:St|Street|Ave|Avenue|Blvd|Boulevard|Dr|Drive|Rd|Road|Ln|Lane|Way|Ct|Court)\b'
    r'|\bP\.?O\.?\s*Box\s+\d+\b'
    r'|\b\d{5}(?:-\d{4})?\b',  # Zip code
    re.IGNORECASE,
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def detect_emails(text: str) -> List[str]:
    """Return every email address matched in ``text`` ([] for non-strings)."""
    if not isinstance(text, str):
        return []
    return _EMAIL_PATTERN.findall(text)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def mask_emails(text: str, mask: str = "[REDACTED_EMAIL]") -> str:
    """Substitute ``mask`` for each email address; non-strings pass through."""
    if not isinstance(text, str):
        return text
    return _EMAIL_PATTERN.sub(mask, text)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def detect_phones(text: str) -> List[str]:
    """Return every phone-number match in ``text`` ([] for non-strings)."""
    if not isinstance(text, str):
        return []
    return _PHONE_PATTERN.findall(text)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def mask_phones(text: str, mask: str = "[REDACTED_PHONE]") -> str:
    """Substitute ``mask`` for each phone number; non-strings pass through."""
    if not isinstance(text, str):
        return text
    return _PHONE_PATTERN.sub(mask, text)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def detect_id_numbers(text: str) -> List[str]:
    """Return every SSN/CNIC-like match in ``text`` ([] for non-strings)."""
    if not isinstance(text, str):
        return []
    return _ID_NUMBER_PATTERN.findall(text)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def mask_id_numbers(text: str, mask: str = "[REDACTED_ID]") -> str:
    """Substitute ``mask`` for each ID-number match; non-strings pass through."""
    if not isinstance(text, str):
        return text
    return _ID_NUMBER_PATTERN.sub(mask, text)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def detect_api_keys(text: str) -> List[str]:
    """Return every API-key / secret-like match in ``text`` ([] for non-strings)."""
    if not isinstance(text, str):
        return []
    return _API_KEY_PATTERN.findall(text)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def mask_api_keys(text: str, mask: str = "[REDACTED_KEY]") -> str:
    """Substitute ``mask`` for each API-key match; non-strings pass through."""
    if not isinstance(text, str):
        return text
    return _API_KEY_PATTERN.sub(mask, text)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def detect_addresses(text: str) -> List[str]:
    """Return every address-like match in ``text`` ([] for non-strings)."""
    if not isinstance(text, str):
        return []
    return _ADDRESS_PATTERN.findall(text)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def mask_addresses(text: str, mask: str = "[REDACTED_ADDR]") -> str:
    """Substitute ``mask`` for each address match; non-strings pass through."""
    if not isinstance(text, str):
        return text
    return _ADDRESS_PATTERN.sub(mask, text)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def apply_pii_filter(
    text: str,
    config: PIIFilterConfig,
) -> str:
    """Mask each PII category that ``config`` enables in one text value.

    Categories are processed in a fixed order (emails, phones, ID numbers,
    API keys, addresses), and every match is replaced with
    ``config.mask_char`` regardless of category.
    """
    replacement = config.mask_char

    # (enabled?, masker) pairs, in application order.
    passes = (
        (config.filter_emails, mask_emails),
        (config.filter_phones, mask_phones),
        (config.filter_id_numbers, mask_id_numbers),
        (config.filter_api_keys, mask_api_keys),
        (config.filter_addresses, mask_addresses),
    )
    for enabled, masker in passes:
        if enabled:
            text = masker(text, replacement)

    return text
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def apply_pii_filter_df(
    df: pd.DataFrame,
    columns: List[str],
    config: PIIFilterConfig,
) -> pd.DataFrame:
    """Return a copy of ``df`` with PII masked in the listed columns.

    Names in ``columns`` that are not present in ``df`` are ignored. Each
    cell is converted with ``str`` before masking, so processed columns end
    up holding strings.
    """
    masked = df.copy()
    targets = [name for name in columns if name in masked.columns]

    for name in targets:
        masked[name] = masked[name].apply(
            lambda value: apply_pii_filter(str(value), config)
        )

    return masked
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def detect_pii_summary(
    df: pd.DataFrame,
    columns: List[str],
) -> Dict[str, int]:
    """Count PII matches per category across the given columns.

    Columns missing from ``df`` are skipped; cell values are stringified
    before scanning.

    Returns:
        A dict with the keys "emails", "phones", "id_numbers", "api_keys"
        and "addresses", each holding the total number of matches.
    """
    summary = {"emails": 0, "phones": 0, "id_numbers": 0, "api_keys": 0, "addresses": 0}

    present = [name for name in columns if name in df.columns]
    for name in present:
        for cell in df[name].astype(str):
            summary["emails"] += len(detect_emails(cell))
            summary["phones"] += len(detect_phones(cell))
            summary["id_numbers"] += len(detect_id_numbers(cell))
            summary["api_keys"] += len(detect_api_keys(cell))
            summary["addresses"] += len(detect_addresses(cell))

    return summary
|
preprocessing/pipeline.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Preprocessing Pipeline Runner
|
| 3 |
+
================================
|
| 4 |
+
Central pipeline that runs all enabled preprocessing stages
|
| 5 |
+
sequentially and logs each step.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 10 |
+
import time
|
| 11 |
+
import pandas as pd
|
| 12 |
+
|
| 13 |
+
from preprocessing.text_cleaning import TextCleaningConfig, apply_text_cleaning
|
| 14 |
+
from preprocessing.tokenization import (
|
| 15 |
+
TokenizationConfig, get_tokenizer, compute_token_stats,
|
| 16 |
+
truncate_samples, split_long_samples,
|
| 17 |
+
)
|
| 18 |
+
from preprocessing.system_prompt import SystemPromptConfig
|
| 19 |
+
from preprocessing.dataset_balancing import BalancingConfig, balance_dataset
|
| 20 |
+
from preprocessing.quality_filters import QualityFilterConfig, apply_quality_filters
|
| 21 |
+
from preprocessing.deduplication import DeduplicationConfig, apply_deduplication
|
| 22 |
+
from preprocessing.train_val_split import SplitConfig, split_dataset
|
| 23 |
+
from preprocessing.output_formatter import OutputFormatConfig, format_dataset, export_jsonl
|
| 24 |
+
from preprocessing.pii_filter import PIIFilterConfig, apply_pii_filter_df
|
| 25 |
+
from preprocessing.augmentation import AugmentationConfig, augment_dataset
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class PreprocessingConfig:
    """Top-level settings bundle: column mappings plus one sub-config per stage."""

    # --- Column mappings (empty / None means "not mapped") ---
    instruction_col: str = ""
    output_col: str = ""
    input_col: Optional[str] = None
    label_col: Optional[str] = None

    # --- Per-stage sub-configs, each created fresh with its own defaults ---
    text_cleaning: TextCleaningConfig = field(default_factory=TextCleaningConfig)
    tokenization: TokenizationConfig = field(default_factory=TokenizationConfig)
    system_prompt: SystemPromptConfig = field(default_factory=SystemPromptConfig)
    balancing: BalancingConfig = field(default_factory=BalancingConfig)
    quality_filters: QualityFilterConfig = field(default_factory=QualityFilterConfig)
    deduplication: DeduplicationConfig = field(default_factory=DeduplicationConfig)
    split: SplitConfig = field(default_factory=SplitConfig)
    output_format: OutputFormatConfig = field(default_factory=OutputFormatConfig)
    pii_filter: PIIFilterConfig = field(default_factory=PIIFilterConfig)
    augmentation: AugmentationConfig = field(default_factory=AugmentationConfig)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@dataclass
class PipelineLog:
    """Record of one executed pipeline stage."""

    stage: str
    description: str
    rows_before: int
    rows_after: int
    duration_ms: float

    @property
    def rows_delta(self) -> int:
        """Change in row count caused by the stage (negative when rows were dropped)."""
        delta = self.rows_after - self.rows_before
        return delta
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class PreprocessingPipeline:
    """
    Sequential preprocessing pipeline runner.

    Runs seven stages in a fixed order (text cleaning, quality filters,
    deduplication, PII filtering, balancing, augmentation, tokenization)
    and then the train/validation split. A stage that the configuration
    leaves disabled is skipped and produces no log entry.
    """

    def __init__(self, config: PreprocessingConfig):
        self.config = config
        self.logs: List[PipelineLog] = []

    def _log(self, stage: str, desc: str, before: int, after: int, elapsed: float):
        """Append a log entry; ``elapsed`` is in seconds and stored as milliseconds."""
        self.logs.append(PipelineLog(
            stage=stage,
            description=desc,
            rows_before=before,
            rows_after=after,
            duration_ms=round(elapsed * 1000, 1),
        ))

    def _timed(self, stage: str, desc: str, df: pd.DataFrame, transform) -> pd.DataFrame:
        """Apply ``transform(df)``, log row counts and duration, return the result."""
        started = time.time()
        rows_before = len(df)
        result = transform(df)
        self._log(stage, desc, rows_before, len(result), time.time() - started)
        return result

    def _stage_text_cleaning(self, df: pd.DataFrame, text_cols: List[str]) -> pd.DataFrame:
        """Stage 1: clean the mapped text columns if any cleaning option is on."""
        opts = self.config.text_cleaning
        any_cleaning = (
            opts.remove_html or opts.remove_urls or
            opts.remove_emojis or opts.normalize_whitespace or
            opts.lowercase or opts.remove_special_chars or
            opts.strip_extra_linebreaks
        )
        if not any_cleaning:
            return df
        return self._timed(
            "Text Cleaning", "Applied text cleaning filters", df,
            lambda frame: apply_text_cleaning(frame, text_cols, opts),
        )

    def _stage_quality_filters(self, df: pd.DataFrame, text_cols: List[str]) -> pd.DataFrame:
        """Stage 2: filter rows on the output column if any quality option is on."""
        cfg = self.config
        opts = cfg.quality_filters
        has_quality = (
            opts.min_word_count > 0 or
            opts.max_word_count > 0 or
            opts.profanity_filter or
            opts.language_filter or
            opts.remove_low_quality
        )
        if not (has_quality and cfg.output_col):
            return df
        return self._timed(
            "Quality Filters", "Applied quality filters", df,
            lambda frame: apply_quality_filters(frame, cfg.output_col, opts),
        )

    def _stage_deduplication(self, df: pd.DataFrame, text_cols: List[str]) -> pd.DataFrame:
        """Stage 3: deduplicate on the instruction column."""
        cfg = self.config
        opts = cfg.deduplication
        if not (cfg.instruction_col and (opts.remove_exact or opts.remove_semantic)):
            return df
        return self._timed(
            "Deduplication", "Removed duplicate samples", df,
            lambda frame: apply_deduplication(frame, cfg.instruction_col, opts),
        )

    def _stage_pii(self, df: pd.DataFrame, text_cols: List[str]) -> pd.DataFrame:
        """Stage 4: mask PII in the mapped text columns."""
        opts = self.config.pii_filter
        has_pii = (
            opts.filter_emails or opts.filter_phones or
            opts.filter_id_numbers or opts.filter_api_keys or
            opts.filter_addresses
        )
        if not has_pii:
            return df
        return self._timed(
            "PII Filtering", "Masked PII data", df,
            lambda frame: apply_pii_filter_df(frame, text_cols, opts),
        )

    def _stage_balancing(self, df: pd.DataFrame, text_cols: List[str]) -> pd.DataFrame:
        """Stage 5: balance classes by the configured label column and strategy."""
        opts = self.config.balancing
        if not (opts.enabled and opts.label_column and opts.strategy != "none"):
            return df
        return self._timed(
            "Balancing", "Balanced dataset classes", df,
            lambda frame: balance_dataset(frame, opts.label_column, opts.strategy),
        )

    def _stage_augmentation(self, df: pd.DataFrame, text_cols: List[str]) -> pd.DataFrame:
        """Stage 6: augment samples based on the instruction column."""
        cfg = self.config
        if not (cfg.augmentation.enabled and cfg.instruction_col):
            return df
        return self._timed(
            "Augmentation", "Generated augmented samples", df,
            lambda frame: augment_dataset(frame, cfg.instruction_col, cfg.augmentation),
        )

    def _stage_tokenization(self, df: pd.DataFrame, text_cols: List[str]) -> pd.DataFrame:
        """Stage 7: split or truncate over-long samples in each text column.

        Splitting takes precedence over truncation when both are enabled.
        If the tokenizer backend cannot be imported the stage stays
        best-effort (no exception escapes), but the log entry says the stage
        was skipped instead of claiming it was applied.
        """
        opts = self.config.tokenization
        if not (opts.truncate_long or opts.split_long):
            return df

        started = time.time()
        rows_before = len(df)
        description = "Applied tokenization controls"
        try:
            tokenizer = get_tokenizer(opts)
            is_tiktoken = opts.tokenizer_name == "tiktoken"

            for col in text_cols:
                if opts.split_long:
                    df = split_long_samples(
                        df, col, opts.max_total_tokens,
                        tokenizer, is_tiktoken, opts.split_overlap,
                    )
                elif opts.truncate_long:
                    df = truncate_samples(
                        df, col, opts.max_total_tokens,
                        tokenizer, is_tiktoken,
                    )
        except ImportError as exc:
            # Columns processed before the failure keep their changes.
            description = f"Skipped tokenization controls: tokenizer not available ({exc})"
        self._log("Tokenization", description, rows_before, len(df), time.time() - started)
        return df

    def run(
        self,
        df: pd.DataFrame,
        progress_callback=None,
    ) -> Tuple[pd.DataFrame, pd.DataFrame, List[PipelineLog]]:
        """
        Run the complete preprocessing pipeline.

        Args:
            df: Input DataFrame
            progress_callback: Optional callable(stage_name, progress_pct) for UI updates.
                It is invoked once after every stage, whether or not the stage was enabled.

        Returns:
            (train_df, val_df, logs)
            If split is disabled, val_df will be empty.
        """
        self.logs = []
        cfg = self.config
        # Mapped text columns that actually exist in the input frame.
        text_cols = [c for c in [cfg.instruction_col, cfg.output_col, cfg.input_col] if c and c in df.columns]

        stages = (
            ("Text Cleaning", self._stage_text_cleaning),
            ("Quality Filters", self._stage_quality_filters),
            ("Deduplication", self._stage_deduplication),
            ("PII Filtering", self._stage_pii),
            ("Balancing", self._stage_balancing),
            ("Augmentation", self._stage_augmentation),
            ("Tokenization", self._stage_tokenization),
        )
        total_stages = len(stages)

        for position, (name, stage) in enumerate(stages, start=1):
            df = stage(df, text_cols)
            if progress_callback:
                pct = int((position / total_stages) * 100)
                progress_callback(name, pct)

        # ── Split ──
        train_df, val_df = split_dataset(df, cfg.split)

        return train_df, val_df, self.logs
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def get_safe_preset() -> PreprocessingConfig:
    """Build the 'safe preset': a conservative default configuration.

    It turns on HTML/URL stripping and whitespace/linebreak normalisation,
    a 3-word minimum plus low-quality removal (20-character floor),
    exact-duplicate removal, masking of emails/phones/IDs/API keys, a shuffled
    90% train split with seed 42, and OpenAI chat output with a generic
    system prompt. Column mappings are left unset.
    """
    cleaning = TextCleaningConfig(
        remove_html=True,
        remove_urls=True,
        remove_emojis=False,
        normalize_whitespace=True,
        lowercase=False,
        remove_special_chars=False,
        strip_extra_linebreaks=True,
    )
    quality = QualityFilterConfig(
        min_word_count=3,
        max_word_count=0,
        profanity_filter=False,
        language_filter=False,
        remove_low_quality=True,
        min_quality_length=20,
    )
    dedup = DeduplicationConfig(
        remove_exact=True,
        remove_semantic=False,
    )
    pii = PIIFilterConfig(
        filter_emails=True,
        filter_phones=True,
        filter_id_numbers=True,
        filter_api_keys=True,
        filter_addresses=False,
    )
    splitting = SplitConfig(
        enabled=True,
        train_ratio=0.9,
        random_seed=42,
        shuffle=True,
    )
    output = OutputFormatConfig(
        format_type="openai_chat",
    )
    prompt = SystemPromptConfig(
        system_prompt="You are a helpful AI assistant.",
        prepend_to_all=True,
    )

    return PreprocessingConfig(
        text_cleaning=cleaning,
        quality_filters=quality,
        deduplication=dedup,
        pii_filter=pii,
        split=splitting,
        output_format=output,
        system_prompt=prompt,
    )
|
preprocessing/quality_filters.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quality Filters Module
|
| 3 |
+
========================
|
| 4 |
+
Filter samples by word count, profanity, language,
|
| 5 |
+
and low-quality response detection.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from typing import List, Optional
|
| 10 |
+
import re
|
| 11 |
+
import pandas as pd
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class QualityFilterConfig:
    """Thresholds and switches for the quality-filter stage; all off by default."""

    # Word-count bounds; 0 disables the respective bound.
    min_word_count: int = 0
    max_word_count: int = 0  # 0 = no limit
    profanity_filter: bool = False
    language_filter: bool = False
    # Language codes accepted when language_filter is on.
    allowed_languages: List[str] = field(default_factory=lambda: ["en"])
    remove_low_quality: bool = False
    # Minimum stripped length for a response not to count as low quality.
    min_quality_length: int = 20
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Profanity word list (small built-in set, extend as needed)
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
_PROFANITY_WORDS = {
    'fuck',
    'shit',
    'damn',
    'ass',
    'bitch',
    'bastard',
    'crap',
    'dick',
    'piss',
    'slut',
    'whore',
    'cock',
}

# Generic filler/placeholder responses that indicate low quality.
# All entries are lower-case.
_GENERIC_RESPONSES = [
    "i don't know",
    "i am not sure",
    "no comment",
    "n/a",
    "none",
    "null",
    "test",
    "asdf",
    "lorem ipsum",
    "placeholder",
    "todo",
    "tbd",
]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _word_count(text: str) -> int:
    """Number of whitespace-separated tokens in ``text``; 0 for non-strings."""
    return len(text.split()) if isinstance(text, str) else 0
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def filter_by_word_count(
    df: pd.DataFrame,
    col: str,
    min_words: int = 0,
    max_words: int = 0,
) -> pd.DataFrame:
    """Filter rows by word count in the given column.

    Args:
        df: Source rows (not modified).
        col: Column whose text is measured.
        min_words: Keep rows with at least this many words; ``0`` or less
            disables the lower bound.
        max_words: Keep rows with at most this many words; ``0`` or less
            disables the upper bound.

    Returns:
        A new frame with the surviving rows and a fresh 0..n-1 index.
    """
    # A positional (NumPy) mask is built once so the result never depends on
    # index labels; label-based re-selection breaks on non-unique indexes.
    counts = df[col].apply(_word_count).to_numpy(dtype=int)
    keep = counts >= 0  # all True: word counts are never negative

    if min_words > 0:
        keep &= counts >= min_words

    if max_words > 0:
        keep &= counts <= max_words

    return df.loc[keep].reset_index(drop=True)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def contains_profanity(text: str) -> bool:
    """Report whether any whole word of ``text`` is in the profanity list.

    Matching is case-insensitive and word-based; non-string input is never
    flagged.
    """
    if not isinstance(text, str):
        return False
    tokens = re.findall(r'\b\w+\b', text.lower())
    return any(token in _PROFANITY_WORDS for token in tokens)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def filter_profanity(
    df: pd.DataFrame,
    col: str,
) -> pd.DataFrame:
    """Drop every row whose ``col`` text contains profanity; the index is renumbered."""
    clean_rows = ~df[col].apply(contains_profanity)
    survivors = df[clean_rows]
    return survivors.reset_index(drop=True)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def detect_language(text: str) -> str:
    """
    Guess the language of ``text`` using ``langdetect``.

    Returns the code reported by ``langdetect.detect`` (e.g., 'en', 'fr', 'de').
    Returns 'unknown' when the input is not a string, has fewer than 10
    non-padding characters, ``langdetect`` is not installed, or detection
    raises.
    """
    if not isinstance(text, str) or len(text.strip()) < 10:
        return 'unknown'
    try:
        from langdetect import detect
        return detect(text)
    except Exception:
        # Covers both a missing langdetect install and detection failures.
        return 'unknown'
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def filter_by_language(
    df: pd.DataFrame,
    col: str,
    allowed_langs: List[str] = None,
) -> pd.DataFrame:
    """Keep rows whose detected language is allowed or could not be determined.

    Args:
        df: Source rows.
        col: Column whose text is language-detected.
        allowed_langs: Accepted language codes; defaults to ``['en']``.

    Returns:
        The surviving rows with a renumbered index. Rows detected as
        'unknown' are always kept.
    """
    permitted = ['en'] if allowed_langs is None else allowed_langs

    detected = df[col].apply(detect_language)
    keep = (detected == 'unknown') | detected.isin(permitted)
    return df[keep].reset_index(drop=True)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def is_low_quality(text: str, min_len: int = 20) -> bool:
    """
    Check if a response is low-quality:
    - Not a string
    - Too short (fewer than ``min_len`` characters after stripping)
    - Equals, or opens with, a generic/placeholder phrase

    A placeholder phrase only counts when it ends at a word boundary, so
    "test" flags "Test. Please ignore this." but not "Testing the model ...".

    Args:
        text: Response text to inspect.
        min_len: Minimum stripped length for an acceptable response.

    Returns:
        True when the response should be treated as low quality.
    """
    if not isinstance(text, str):
        return True
    text_stripped = text.strip()
    if len(text_stripped) < min_len:
        return True
    text_lower = text_stripped.lower()
    for phrase in _GENERIC_RESPONSES:
        if not text_lower.startswith(phrase):
            continue
        following = text_lower[len(phrase):len(phrase) + 1]
        # An empty remainder means an exact match; otherwise the next
        # character must not continue the word.
        if not following or not (following.isalnum() or following == "_"):
            return True
    return False
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def filter_low_quality(
    df: pd.DataFrame,
    col: str,
    min_len: int = 20,
) -> pd.DataFrame:
    """Drop rows whose ``col`` value is judged low quality; the index is renumbered."""

    def _acceptable(value) -> bool:
        return not is_low_quality(value, min_len)

    flagged = df[col].apply(lambda t: not _acceptable(t))
    return df[~flagged].reset_index(drop=True)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def apply_quality_filters(
    df: pd.DataFrame,
    col: str,
    config: QualityFilterConfig,
) -> pd.DataFrame:
    """
    Run every quality filter that ``config`` enables, in a fixed order.

    Order: word count, profanity, language, low-quality responses. Each
    stage receives the output of the previous one.
    """
    result = df

    word_count_enabled = config.min_word_count > 0 or config.max_word_count > 0
    if word_count_enabled:
        result = filter_by_word_count(
            result, col, config.min_word_count, config.max_word_count
        )

    if config.profanity_filter:
        result = filter_profanity(result, col)

    if config.language_filter:
        result = filter_by_language(result, col, config.allowed_languages)

    if config.remove_low_quality:
        result = filter_low_quality(result, col, config.min_quality_length)

    return result
|
preprocessing/system_prompt.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
System Prompt Configuration Module
|
| 3 |
+
=====================================
|
| 4 |
+
Manage global system prompts, prepend to samples,
|
| 5 |
+
and preview formatted chat JSON.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
import json
|
| 11 |
+
import pandas as pd
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class SystemPromptConfig:
    """Configuration for system prompt handling."""
    # Text of the global system prompt.
    system_prompt: str = "You are a helpful AI assistant."
    # Whether the prompt should be prepended to every sample.
    # NOTE(review): no function in this module reads this flag — confirm
    # where it is consumed.
    prepend_to_all: bool = True
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_chat_json(
    instruction: str,
    output: str,
    system_prompt: str = "",
    context: str = "",
) -> Dict[str, Any]:
    """
    Build a single chat-format JSON entry.

    Args:
        instruction: User instruction text.
        output: Assistant response text.
        system_prompt: Optional system message; omitted when empty.
        context: Optional extra input, appended to the user message as
            ``"\\n\\nContext: <context>"`` when non-empty.

    Returns:
        ``{"messages": [{"role": ..., "content": ...}, ...]}`` with an
        optional system message followed by the user and assistant turns.
    """
    messages = []

    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    user_content = instruction
    if context:
        user_content += f"\n\nContext: {context}"

    messages.append({"role": "user", "content": user_content})
    messages.append({"role": "assistant", "content": output})

    return {"messages": messages}


def _cell_to_text(value: Any) -> str:
    """Convert a DataFrame cell to text, mapping None/NaN/NA to ''."""
    # str(nan) is the truthy string "nan"; without this guard a missing
    # context cell would be rendered as "Context: nan".
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ''
    return str(value)


def preview_formatted(
    df: pd.DataFrame,
    system_prompt: str,
    instruction_col: str,
    output_col: str,
    input_col: Optional[str] = None,
    n: int = 3,
) -> List[Dict[str, Any]]:
    """
    Generate a preview of n formatted chat-JSON samples.

    Args:
        df: Source frame; only the first ``n`` rows are used.
        system_prompt: System message passed to ``build_chat_json``.
        instruction_col: Column holding the user instruction.
        output_col: Column holding the assistant response.
        input_col: Optional context column; ignored when falsy or absent
            from ``df``.
        n: Maximum number of samples to return.

    Returns:
        One chat-format dict per previewed row. Missing cells (None/NaN)
        are treated as empty strings.
    """
    # The column check is row-independent, so evaluate it once.
    has_context = bool(input_col) and input_col in df.columns

    previews = []
    for _, row in df.head(n).iterrows():
        instruction = _cell_to_text(row.get(instruction_col, ''))
        output = _cell_to_text(row.get(output_col, ''))
        context = _cell_to_text(row.get(input_col, '')) if has_context else ''
        previews.append(
            build_chat_json(instruction, output, system_prompt, context)
        )
    return previews
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def preview_formatted_json(
    df: pd.DataFrame,
    system_prompt: str,
    instruction_col: str,
    output_col: str,
    input_col: Optional[str] = None,
    n: int = 3,
) -> str:
    """
    Render the first ``n`` chat-format samples as an indented JSON string.

    All arguments are forwarded to ``preview_formatted``. Non-ASCII
    characters are written as-is rather than escaped.
    """
    entries = preview_formatted(
        df=df,
        system_prompt=system_prompt,
        instruction_col=instruction_col,
        output_col=output_col,
        input_col=input_col,
        n=n,
    )
    rendered = json.dumps(entries, indent=2, ensure_ascii=False)
    return rendered
|
preprocessing/text_cleaning.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Text Cleaning Module
|
| 3 |
+
=====================
|
| 4 |
+
Pure functions for text preprocessing toggles.
|
| 5 |
+
Each function operates on a single string and can be
|
| 6 |
+
composed via apply_text_cleaning().
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import re
|
| 10 |
+
import unicodedata
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import List
|
| 13 |
+
import pandas as pd
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class TextCleaningConfig:
    """Configuration for text cleaning options."""
    # Each flag enables one step of clean_text(); the steps run in the
    # order html -> urls -> emojis -> special chars -> lowercase ->
    # whitespace -> linebreaks, regardless of field order here.
    remove_html: bool = False  # run remove_html_tags()
    remove_urls: bool = False  # run remove_urls()
    remove_emojis: bool = False  # run remove_emojis()
    normalize_whitespace: bool = True  # run normalize_whitespace()
    lowercase: bool = False  # run to_lowercase()
    remove_special_chars: bool = False  # run remove_special_characters()
    strip_extra_linebreaks: bool = True  # run strip_extra_linebreaks()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Individual cleaning functions
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
# Matches comments, processing instructions, declarations and element tags.
# A tag must start with a letter (optionally after "/"), so comparison
# operators in plain text such as "x < 5 and y > 3" are left untouched.
_HTML_TAG_PATTERN = re.compile(
    r'<!--.*?-->'            # <!-- comment -->
    r'|<\?[^>]*>'            # <?xml ... ?>
    r'|<![A-Za-z][^>]*>'     # <!DOCTYPE ...>
    r'|</?[A-Za-z][^>]*>',   # <tag ...>, </tag>, <tag/>
    re.DOTALL,
)


def remove_html_tags(text: str) -> str:
    """
    Strip HTML/XML tags, comments and declarations from text.

    Only markup-shaped spans are removed: "<" must be followed by a
    letter, "/", "!" or "?". A bare "<" followed by a space or digit is
    treated as ordinary text and preserved.
    """
    return _HTML_TAG_PATTERN.sub('', text)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def remove_urls(text: str) -> str:
    """
    Delete URL-like tokens from text.

    A token is removed when it starts with ``http://``, ``https://``,
    ``ftp://`` or ``www.``; removal extends to the next whitespace.
    """
    url_pattern = r'https?://\S+|ftp://\S+|www\.\S+'
    without_urls = re.sub(url_pattern, '', text)
    return without_urls
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Character class of emoji code points.
# The previous pattern contained the single range U+24C2..U+1F251, which
# spans the CJK ideograph, kana and Hangul blocks and therefore deleted all
# Chinese, Japanese and Korean text. It is replaced by the narrow ranges
# below.
_EMOJI_PATTERN = re.compile(
    "["
    "\U0001F600-\U0001F64F"  # emoticons
    "\U0001F300-\U0001F5FF"  # symbols & pictographs
    "\U0001F680-\U0001F6FF"  # transport & map symbols
    "\U0001F1E0-\U0001F1FF"  # flags
    "\U0001F200-\U0001F251"  # enclosed ideographic supplement
    "\U0001F900-\U0001F9FF"  # supplemental symbols
    "\U0001FA00-\U0001FA6F"  # chess symbols
    "\U0001FA70-\U0001FAFF"  # symbols & pictographs extended-A
    "\U00002600-\U000026FF"  # miscellaneous symbols
    "\U00002702-\U000027B0"  # dingbats
    "\U000024C2"             # circled M
    "\U00002B50"             # star
    "\U00002B55"             # heavy large circle
    "\U0000FE0F"             # variation selector-16 (emoji presentation)
    "]+",
    flags=re.UNICODE,
)


def remove_emojis(text: str) -> str:
    """
    Remove emoji characters from text.

    Only code points listed in ``_EMOJI_PATTERN`` are removed; letters of
    any script (including CJK, kana and Hangul) are preserved.
    """
    return _EMOJI_PATTERN.sub('', text)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def normalize_whitespace(text: str) -> str:
    """
    Squeeze each run of non-newline whitespace into one space.

    Newlines are left in place; leading and trailing whitespace of the
    whole string (newlines included) is trimmed afterwards.
    """
    collapsed = re.sub(r'[^\S\n]+', ' ', text)
    return collapsed.strip()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def to_lowercase(text: str) -> str:
    """Return a lower-cased copy of ``text``."""
    lowered = text.lower()
    return lowered
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def remove_special_characters(text: str) -> str:
    """
    Drop every character outside a small ASCII whitelist.

    Kept: ASCII letters and digits, whitespace, and the punctuation
    ``. , ! ? ; : ' " ( ) -``. Everything else is deleted.
    """
    disallowed = r'[^a-zA-Z0-9\s.,!?;:\'"()\-\n]'
    filtered = re.sub(disallowed, '', text)
    return filtered
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def strip_extra_linebreaks(text: str) -> str:
    """Collapse any run of three or more newlines into exactly two."""
    long_break = r'\n{3,}'
    return re.sub(long_break, '\n\n', text)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
# Composed cleaner
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
def clean_text(text: str, config: "TextCleaningConfig") -> str:
    """
    Apply all enabled cleaning steps to a single text string.

    Args:
        text: Value to clean. Non-string values are not cleaned: missing
            values (None, NaN, pd.NA, NaT) become ``''``, other falsy
            values become ``''`` and everything else is returned as
            ``str(text)``.
        config: Flags selecting which steps run.

    Returns:
        The cleaned string. Steps run in the order html, urls, emojis,
        special characters, lowercase, whitespace, linebreaks.
    """
    if not isinstance(text, str):
        # NaN is truthy, so a plain truthiness test would turn a missing
        # pandas cell into the literal string "nan" (and pd.NA would raise
        # on bool()). Treat missing scalars as empty text first.
        if pd.api.types.is_scalar(text) and pd.isna(text):
            return ''
        return str(text) if text else ''

    if config.remove_html:
        text = remove_html_tags(text)
    if config.remove_urls:
        text = remove_urls(text)
    if config.remove_emojis:
        text = remove_emojis(text)
    if config.remove_special_chars:
        text = remove_special_characters(text)
    if config.lowercase:
        text = to_lowercase(text)
    if config.normalize_whitespace:
        text = normalize_whitespace(text)
    if config.strip_extra_linebreaks:
        text = strip_extra_linebreaks(text)

    return text
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def apply_text_cleaning(
    df: pd.DataFrame,
    columns: List[str],
    config: TextCleaningConfig,
) -> pd.DataFrame:
    """
    Clean the listed columns of a frame with ``clean_text``.

    The input frame is not modified; a copy is returned. Names in
    ``columns`` that are not present in the frame are skipped.
    """
    cleaned = df.copy()
    targets = [name for name in columns if name in cleaned.columns]
    for name in targets:
        cleaned[name] = cleaned[name].apply(
            lambda value: clean_text(value, config)
        )
    return cleaned
|
preprocessing/tokenization.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tokenization Controls Module
|
| 3 |
+
==============================
|
| 4 |
+
Tokenizer selection, token counting, truncation, and splitting.
|
| 5 |
+
Supports tiktoken (OpenAI) and HuggingFace tokenizers.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Dict, List, Any, Optional
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class TokenizationConfig:
    """Configuration for tokenization controls."""
    # get_tokenizer() uses tiktoken when this equals "tiktoken"; any other
    # value is passed to AutoTokenizer.from_pretrained().
    tokenizer_name: str = "tiktoken"  # "tiktoken" or HF model name
    # Encoding name handed to tiktoken.get_encoding().
    tiktoken_encoding: str = "cl100k_base"  # for tiktoken
    # NOTE(review): the four fields below are not read anywhere in this
    # module; presumably the pipeline maps them onto truncate_samples() /
    # split_long_samples() arguments — confirm against the caller.
    max_total_tokens: int = 2048
    truncate_long: bool = False
    split_long: bool = False
    split_overlap: int = 50  # overlap tokens when splitting
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_tokenizer(config: TokenizationConfig):
    """
    Return a tokenizer-like object.

    Args:
        config: ``tokenizer_name == "tiktoken"`` selects tiktoken with
            ``config.tiktoken_encoding``; any other name is loaded through
            ``transformers.AutoTokenizer.from_pretrained``.

    Returns:
        The tiktoken encoding object, or the AutoTokenizer instance.

    Raises:
        ImportError: If the required backend package is not installed.
            The original import failure is attached as ``__cause__``.
    """
    if config.tokenizer_name == "tiktoken":
        # Keep the try body to the import alone so only a missing package
        # is rewritten into the install hint.
        try:
            import tiktoken
        except ImportError as exc:
            raise ImportError(
                "tiktoken is required. Install with: pip install tiktoken"
            ) from exc
        return tiktoken.get_encoding(config.tiktoken_encoding)

    try:
        from transformers import AutoTokenizer
    except ImportError as exc:
        raise ImportError("transformers is required for HF tokenizers.") from exc
    return AutoTokenizer.from_pretrained(config.tokenizer_name)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def count_tokens(text: str, tokenizer, is_tiktoken: bool = True) -> int:
    """
    Count tokens in a text string.

    Args:
        text: Text to measure; non-strings and blank strings count as 0.
        tokenizer: Object returned by ``get_tokenizer``.
        is_tiktoken: True for a tiktoken encoding, False for a HuggingFace
            tokenizer (encoded without special tokens).

    Returns:
        Number of tokens produced by ``tokenizer.encode``.
    """
    if not isinstance(text, str) or not text.strip():
        return 0
    if is_tiktoken:
        # tiktoken raises ValueError by default when the text contains a
        # special-token string such as "<|endoftext|>". Dataset text may
        # legitimately contain those, so encode them as ordinary text.
        return len(tokenizer.encode(text, disallowed_special=()))
    return len(tokenizer.encode(text, add_special_tokens=False))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def compute_token_stats(
    df: pd.DataFrame,
    columns: List[str],
    tokenizer,
    is_tiktoken: bool = True,
) -> Dict[str, Dict[str, float]]:
    """
    Summarise per-row token counts for each requested column.

    Columns missing from ``df`` are skipped. For every remaining column the
    result maps the column name to ``min``, ``max``, ``mean`` (rounded to
    one decimal), ``median``, ``p95`` and ``total``; all but ``total`` are
    0 when the column has no rows.
    """
    summary = {}
    for name in columns:
        if name not in df.columns:
            continue

        token_counts = df[name].apply(
            lambda value: count_tokens(value, tokenizer, is_tiktoken)
        )
        has_rows = len(token_counts) > 0

        summary[name] = {
            'min': int(token_counts.min()) if has_rows else 0,
            'max': int(token_counts.max()) if has_rows else 0,
            'mean': round(float(token_counts.mean()), 1) if has_rows else 0,
            'median': int(token_counts.median()) if has_rows else 0,
            'p95': int(np.percentile(token_counts, 95)) if has_rows else 0,
            'total': int(token_counts.sum()),
        }
    return summary
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def truncate_samples(
    df: pd.DataFrame,
    col: str,
    max_tokens: int,
    tokenizer,
    is_tiktoken: bool = True,
) -> pd.DataFrame:
    """
    Truncate text in a column to max_tokens.

    Args:
        df: Input frame (not modified).
        col: Text column to truncate.
        max_tokens: Maximum number of tokens kept per cell.
        tokenizer: Object returned by ``get_tokenizer``.
        is_tiktoken: True for a tiktoken encoding, False for a HuggingFace
            tokenizer.

    Returns:
        A copy of ``df`` in which cells longer than ``max_tokens`` are
        replaced by the decoded first ``max_tokens`` tokens. Shorter cells
        and non-string cells are returned untouched.
    """

    def _truncate(text):
        if not isinstance(text, str):
            return text
        if is_tiktoken:
            # disallowed_special=() stops tiktoken from raising when the
            # text contains a special-token string such as "<|endoftext|>".
            tokens = tokenizer.encode(text, disallowed_special=())
        else:
            tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) <= max_tokens:
            # Return the original string, not a decode round-trip.
            return text
        return tokenizer.decode(tokens[:max_tokens])

    result = df.copy()
    result[col] = result[col].apply(_truncate)
    return result
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def split_long_samples(
    df: pd.DataFrame,
    col: str,
    max_tokens: int,
    tokenizer,
    is_tiktoken: bool = True,
    overlap: int = 50,
) -> pd.DataFrame:
    """
    Split rows whose text exceeds max_tokens into multiple rows.

    Each chunk holds at most ``max_tokens`` tokens and starts ``overlap``
    tokens before the end of the previous chunk. All other columns are
    copied onto every chunk row.

    Args:
        df: Input frame (not modified).
        col: Text column to split.
        max_tokens: Maximum tokens per chunk; must be at least 1.
        tokenizer: Object returned by ``get_tokenizer``.
        is_tiktoken: True for a tiktoken encoding, False for a HuggingFace
            tokenizer.
        overlap: Tokens shared between consecutive chunks.

    Returns:
        A new frame with a fresh 0..n-1 index. An empty input yields an
        empty frame with the same columns.

    Raises:
        ValueError: If ``max_tokens`` is smaller than 1.
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")

    # Clamp the stride to [1, max_tokens]: an overlap >= max_tokens must
    # still advance, and a negative overlap must not skip tokens.
    step = min(max_tokens, max(1, max_tokens - overlap))

    new_rows = []
    for _, row in df.iterrows():
        text = row[col]
        if not isinstance(text, str):
            new_rows.append(row)
            continue

        if is_tiktoken:
            # disallowed_special=() stops tiktoken from raising when the
            # text contains a special-token string such as "<|endoftext|>".
            tokens = tokenizer.encode(text, disallowed_special=())
        else:
            tokens = tokenizer.encode(text, add_special_tokens=False)

        total = len(tokens)
        if total <= max_tokens:
            new_rows.append(row)
            continue

        for start in range(0, total, step):
            new_row = row.copy()
            new_row[col] = tokenizer.decode(tokens[start:start + max_tokens])
            new_rows.append(new_row)
            # This chunk already reaches the end of the text; a further
            # chunk would contain only tokens repeated from this one.
            if start + max_tokens >= total:
                break

    if not new_rows:
        # pd.DataFrame([]) would lose the column labels.
        return df.iloc[0:0].reset_index(drop=True)
    return pd.DataFrame(new_rows).reset_index(drop=True)
|
preprocessing/train_val_split.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train / Validation Split Module
|
| 3 |
+
==================================
|
| 4 |
+
Split datasets with configurable ratio, seed, and shuffle.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class SplitConfig:
    """Configuration for train/validation split."""
    # When False, split_dataset() returns the whole frame as train.
    enabled: bool = True
    # Fraction of rows assigned to train; must lie in [0, 1].
    train_ratio: float = 0.8  # e.g., 0.8 means 80% train, 20% val
    # Seed used for the shuffle, so splits are reproducible.
    random_seed: int = 42
    # Shuffle rows before cutting; False keeps the original row order.
    shuffle: bool = True


def split_dataset(
    df: pd.DataFrame,
    config: SplitConfig,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Split DataFrame into train and validation sets.

    Args:
        df: Frame to split.
        config: Split settings.

    Returns:
        (train_df, val_df) tuple. When the split is disabled, ``df`` itself
        is returned as train together with an empty frame that has the
        same columns. Otherwise the first ``int(len(df) * train_ratio)``
        rows (after the optional shuffle) form train and the rest form
        validation; both have a fresh 0..n-1 index.

    Raises:
        ValueError: If the split is enabled and ``train_ratio`` is outside
            [0, 1] (NaN included).
    """
    if not config.enabled:
        return df, pd.DataFrame(columns=df.columns)

    # Out-of-range ratios would otherwise be accepted silently: > 1 leaves
    # validation empty, < 0 slices from the end of the frame.
    if not 0.0 <= config.train_ratio <= 1.0:
        raise ValueError(
            f"train_ratio must be between 0 and 1, got {config.train_ratio!r}"
        )

    if config.shuffle:
        df = df.sample(frac=1, random_state=config.random_seed).reset_index(drop=True)

    split_idx = int(len(df) * config.train_ratio)
    train_df = df.iloc[:split_idx].reset_index(drop=True)
    val_df = df.iloc[split_idx:].reset_index(drop=True)

    return train_df, val_df
|
requirements.txt
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Auto-FineTune-Ops Dependencies
|
| 2 |
+
# Core ML Libraries
|
| 3 |
+
unsloth @ git+https://github.com/unslothai/unsloth.git
|
| 4 |
+
trl>=0.7.0
|
| 5 |
+
peft>=0.7.0
|
| 6 |
+
transformers>=4.36.0
|
| 7 |
+
datasets>=2.14.0
|
| 8 |
+
accelerate>=0.25.0
|
| 9 |
+
bitsandbytes>=0.41.0
|
| 10 |
+
|
| 11 |
+
# Data Processing
|
| 12 |
+
pandas>=2.0.0
|
| 13 |
+
numpy>=1.24.0
|
| 14 |
+
|
| 15 |
+
# Advanced Preprocessing
|
| 16 |
+
tiktoken>=0.5.0
|
| 17 |
+
langdetect>=1.0.9
|
| 18 |
+
scikit-learn>=1.3.0
|
| 19 |
+
|
| 20 |
+
# API & Deployment
|
| 21 |
+
fastapi>=0.104.0
|
| 22 |
+
uvicorn>=0.24.0
|
| 23 |
+
python-multipart>=0.0.6
|
| 24 |
+
|
| 25 |
+
# LLM Judge Clients
|
| 26 |
+
openai>=1.0.0
|
| 27 |
+
anthropic>=0.8.0
|
| 28 |
+
|
| 29 |
+
# Utilities
|
| 30 |
+
pyyaml>=6.0
|
| 31 |
+
python-dotenv>=1.0.0
|
| 32 |
+
rich>=13.0.0
|
| 33 |
+
typer>=0.9.0
|
| 34 |
+
tqdm>=4.66.0
|
| 35 |
+
|
| 36 |
+
# Dashboard
|
| 37 |
+
streamlit>=1.32.0
|
| 38 |
+
plotly>=5.18.0
|
| 39 |
+
|
| 40 |
+
# CUDA/Torch (install separately based on your CUDA version)
|
| 41 |
+
# torch>=2.1.0
|
| 42 |
+
# xformers>=0.0.23
|
scripts/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Auto-FineTune-Ops Scripts Package"""
|
| 2 |
+
|
| 3 |
+
from .deploy import DeploymentServer
|
| 4 |
+
|
| 5 |
+
__all__ = ["DeploymentServer"]
|
scripts/deploy.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI Deployment Server
|
| 3 |
+
==========================
|
| 4 |
+
One-click deployment bridge for fine-tuned models.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Optional, List, Dict, Any
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
|
| 13 |
+
from rich.console import Console
|
| 14 |
+
|
| 15 |
+
console = Console()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class GenerationRequest:
    """Request model for text generation."""
    # NOTE(review): this dataclass is not referenced elsewhere in this
    # file; the HTTP layer defines its own pydantic GenerateRequest inside
    # DeploymentServer.create_app().
    prompt: str  # user prompt text
    system_prompt: Optional[str] = None  # optional system prompt
    max_tokens: int = 512  # maximum tokens to generate
    temperature: float = 0.7  # sampling temperature
    top_p: float = 0.9  # top-p sampling parameter
    # No counterpart exists in DeploymentServer.generate() or the API
    # request model, so streaming is presumably unimplemented — TODO confirm.
    stream: bool = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class GenerationResponse:
    """Response model for text generation."""
    # Returned by DeploymentServer.generate().
    generated_text: str  # model output with the prompt removed
    prompt: str  # the caller's prompt, before template formatting
    model: str  # set from DeploymentServer.model_path
    tokens_generated: int  # token count of the generated text
    generation_time: float  # elapsed generation time in seconds
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class DeploymentServer:
    """
    FastAPI-based deployment server for fine-tuned models.

    Features:
    - RESTful API for inference
    - Health check endpoint
    - Batch generation support
    - Automatic model loading
    """

    def __init__(
        self,
        model_path: str,
        host: str = "0.0.0.0",
        port: int = 8000,
        max_seq_length: int = 2048
    ):
        """
        Initialize the deployment server.

        Args:
            model_path: Path to the fine-tuned model
            host: Server host
            port: Server port
            max_seq_length: Maximum sequence length
        """
        self.model_path = model_path
        self.host = host
        self.port = port
        self.max_seq_length = max_seq_length

        # Populated by load_model() / create_app().
        self.model = None
        self.tokenizer = None
        self.app = None

    def load_model(self):
        """
        Load the fine-tuned model and tokenizer.

        Tries Unsloth first (4-bit); if that raises ImportError, falls back
        to plain transformers with automatic device and dtype selection.
        """
        console.print(f"\n[bold blue]📂 Loading model from:[/] {self.model_path}")

        try:
            from unsloth import FastLanguageModel

            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.model_path,
                max_seq_length=self.max_seq_length,
                dtype=None,
                load_in_4bit=True,
            )

            FastLanguageModel.for_inference(self.model)

            console.print("[green]✓ Model loaded successfully[/]")

        except ImportError:
            console.print("[yellow]⚠️ Unsloth not available, trying transformers...[/]")

            from transformers import AutoModelForCausalLM, AutoTokenizer

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                device_map="auto",
                torch_dtype="auto"
            )

            console.print("[green]✓ Model loaded with transformers[/]")

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> GenerationResponse:
        """
        Generate text from the model.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt; replaces the default
                Alpaca preamble when given
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature; 0 or below selects greedy
                decoding instead of sampling
            top_p: Top-p sampling parameter (ignored for greedy decoding)

        Returns:
            GenerationResponse with generated text

        Raises:
            RuntimeError: If load_model() has not been called.
        """
        import time

        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Monotonic clock: wall-clock adjustments cannot skew the timing.
        started = time.perf_counter()

        # Format prompt with Alpaca template
        preamble = system_prompt or (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request."
        )
        formatted_prompt = (
            f"{preamble}\n\n### Instruction:\n{prompt}\n\n### Response:\n"
        )

        # Tokenize
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt"
        ).to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        generation_kwargs = {
            "max_new_tokens": max_tokens,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if temperature > 0:
            generation_kwargs.update(
                do_sample=True, temperature=temperature, top_p=top_p
            )
        else:
            # Sampling requires a strictly positive temperature.
            generation_kwargs["do_sample"] = False

        # Generate
        outputs = self.model.generate(**inputs, **generation_kwargs)

        # Decode only the newly generated tokens. Splitting the full text
        # on "### Response:" would truncate the answer whenever the model
        # itself emits that marker again.
        new_token_ids = outputs[0][prompt_length:]
        generated_text = self.tokenizer.decode(
            new_token_ids, skip_special_tokens=True
        ).strip()

        # Count what the model produced; re-encoding the decoded text can
        # add special tokens and miscount.
        tokens_generated = len(new_token_ids)
        generation_time = time.perf_counter() - started

        return GenerationResponse(
            generated_text=generated_text,
            prompt=prompt,
            model=self.model_path,
            tokens_generated=tokens_generated,
            generation_time=generation_time
        )

    def create_app(self):
        """
        Create the FastAPI application.

        Registers GET /, GET /health, POST /generate and
        POST /generate/batch, stores the app on ``self.app`` and returns it.
        """
        from fastapi import FastAPI, HTTPException
        from fastapi.middleware.cors import CORSMiddleware
        from pydantic import BaseModel, Field

        app = FastAPI(
            title="Auto-FineTune-Ops Inference API",
            description="API for serving fine-tuned LLM models",
            version="1.0.0"
        )

        # CORS middleware. Credentials stay disabled because wildcard
        # origins must not be combined with credentialed requests; the API
        # uses no cookies or auth headers.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=False,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # Pydantic models for API. Bounds reject values that would only
        # fail later inside model.generate().
        class GenerateRequest(BaseModel):
            prompt: str
            system_prompt: Optional[str] = None
            max_tokens: int = Field(512, ge=1)
            temperature: float = Field(0.7, ge=0.0)
            top_p: float = Field(0.9, gt=0.0, le=1.0)

        class GenerateResponse(BaseModel):
            generated_text: str
            prompt: str
            model: str
            tokens_generated: int
            generation_time: float

        class BatchGenerateRequest(BaseModel):
            prompts: List[str]
            system_prompt: Optional[str] = None
            max_tokens: int = Field(512, ge=1)
            temperature: float = Field(0.7, ge=0.0)
            top_p: float = Field(0.9, gt=0.0, le=1.0)

        class HealthResponse(BaseModel):
            status: str
            model: str
            model_loaded: bool

        def _to_api_response(result: GenerationResponse) -> GenerateResponse:
            """Convert the internal dataclass into the API response model."""
            return GenerateResponse(
                generated_text=result.generated_text,
                prompt=result.prompt,
                model=result.model,
                tokens_generated=result.tokens_generated,
                generation_time=result.generation_time
            )

        @app.get("/health", response_model=HealthResponse)
        async def health_check():
            """Health check endpoint."""
            return HealthResponse(
                status="healthy",
                model=self.model_path,
                model_loaded=self.model is not None
            )

        # NOTE(review): the generation endpoints are coroutines that call
        # the blocking self.generate(), so one request is served at a time
        # and /health does not answer while a generation is running.
        @app.post("/generate", response_model=GenerateResponse)
        async def generate_text(request: GenerateRequest):
            """Generate text from a single prompt."""
            if self.model is None:
                raise HTTPException(status_code=503, detail="Model not loaded")

            try:
                result = self.generate(
                    prompt=request.prompt,
                    system_prompt=request.system_prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p
                )
                return _to_api_response(result)
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e))

        @app.post("/generate/batch", response_model=List[GenerateResponse])
        async def batch_generate(request: BatchGenerateRequest):
            """
            Generate text from multiple prompts.

            A failing prompt does not abort the batch; its entry carries
            ``"Error: ..."`` as the generated text.
            """
            if self.model is None:
                raise HTTPException(status_code=503, detail="Model not loaded")

            results = []
            for prompt in request.prompts:
                try:
                    result = self.generate(
                        prompt=prompt,
                        system_prompt=request.system_prompt,
                        max_tokens=request.max_tokens,
                        temperature=request.temperature,
                        top_p=request.top_p
                    )
                    results.append(_to_api_response(result))
                except Exception as e:
                    results.append(GenerateResponse(
                        generated_text=f"Error: {str(e)}",
                        prompt=prompt,
                        model=self.model_path,
                        tokens_generated=0,
                        generation_time=0.0
                    ))

            return results

        @app.get("/")
        async def root():
            """Root endpoint with API info."""
            return {
                "name": "Auto-FineTune-Ops Inference API",
                "version": "1.0.0",
                "model": self.model_path,
                "endpoints": {
                    "/health": "Health check",
                    "/generate": "Generate text (POST)",
                    "/generate/batch": "Batch generation (POST)"
                }
            }

        self.app = app
        return app

    def run(self, reload: bool = False):
        """
        Start the FastAPI server.

        Loads the model and builds the app on first use, then blocks in
        uvicorn until the server stops.

        Args:
            reload: Enable auto-reload for development. Ignored with a
                warning, because the app is an in-memory object bound to
                this instance and uvicorn can only reload an app given as
                an import string.
        """
        import uvicorn

        console.print("\n" + "="*60)
        console.print("[bold magenta]🚀 DEPLOYMENT SERVER[/]")
        console.print("="*60)

        # Load model if not already loaded
        if self.model is None:
            self.load_model()

        # Create app if not already created
        if self.app is None:
            self.create_app()

        if reload:
            console.print(
                "[yellow]Auto-reload is not supported for an in-memory app; "
                "starting without reload.[/]"
            )
            reload = False

        console.print(f"\n[bold green]Starting server at http://{self.host}:{self.port}[/]")
        console.print("[dim]Press Ctrl+C to stop[/]\n")

        uvicorn.run(
            self.app,
            host=self.host,
            port=self.port,
            reload=reload
        )
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def main():
    """
    Command-line entry point.

    Parses ``--model`` (required), ``--host``, ``--port`` and ``--reload``,
    then builds a DeploymentServer and runs it.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Deploy fine-tuned model as API")
    cli.add_argument("--model", required=True, help="Path to fine-tuned model")
    cli.add_argument("--host", default="0.0.0.0", help="Server host")
    cli.add_argument("--port", type=int, default=8000, help="Server port")
    cli.add_argument("--reload", action="store_true", help="Enable auto-reload")
    options = cli.parse_args()

    deployment = DeploymentServer(
        model_path=options.model,
        host=options.host,
        port=options.port,
    )
    deployment.run(reload=options.reload)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
if __name__ == "__main__":
|
| 375 |
+
main()
|