|
|
|
|
|
""" |
|
|
Comprehensive Web Interface for Fine-Tuning and Hosting Mistral Models |
|
|
Provides an easy-to-use UI for training models and hosting them via API |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import subprocess |
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import signal |
|
|
import time |
|
|
import threading |
|
|
import requests |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
import torch |
|
|
|
|
|
|
|
|
# --- Directory layout --------------------------------------------------------
# All paths are resolved relative to this file so the app works regardless of
# the current working directory.
BASE_DIR = Path(__file__).parent
MODELS_DIR = BASE_DIR / "models" / "msp"    # model-serving project root
FT_DIR = MODELS_DIR / "ft"                  # fine-tuning scripts (finetune_mistral7b.py)
INFERENCE_DIR = MODELS_DIR / "inference"    # direct in-process inference helpers
API_DIR = MODELS_DIR / "api"                # FastAPI server script (api_server.py)
DATASET_DIR = BASE_DIR / "dataset"          # bundled training datasets
UPLOADS_DIR = BASE_DIR / "uploads"          # user-uploaded / processed datasets
UPLOADS_DIR.mkdir(exist_ok=True)            # ensure the upload target exists at import time

# Make the training / inference helper modules importable by bare name
# (e.g. `from inference.inference_mistral7b import ...`).
sys.path.insert(0, str(MODELS_DIR))
sys.path.insert(0, str(FT_DIR))
sys.path.insert(0, str(INFERENCE_DIR))
|
|
|
|
|
|
|
|
# --- Process-wide state ------------------------------------------------------
# At most one background training run and one API server at a time. The log
# lists are appended to by daemon reader threads and joined for UI display.
training_process = None   # subprocess.Popen | None — current fine-tuning run
api_process = None        # subprocess.Popen | None — current API server
training_log = []         # list[str] — captured training stdout/stderr lines
api_log = []              # list[str] — captured API server stdout/stderr lines
|
|
|
|
|
|
|
|
|
|
|
def get_device_info():
    """Describe the compute devices visible to torch, one per line."""
    if torch.cuda.is_available():
        # One line per CUDA device.
        lines = [
            f"๐ฎ GPU {idx}: {torch.cuda.get_device_name(idx)}"
            for idx in range(torch.cuda.device_count())
        ]
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        lines = ["๐ Apple Silicon GPU (MPS) detected"]
    else:
        lines = ["๐ป CPU only (training will be slow)"]
    return "\n".join(lines)
|
|
|
|
|
def get_gpu_recommendations():
    """Return training defaults sized to the available GPU memory.

    Returns:
        dict with keys ``batch_size``, ``max_length``, ``lora_r``,
        ``lora_alpha`` and a human-readable ``info`` string.

    Fix vs. original: the CPU-only branch omitted the ``lora_r`` /
    ``lora_alpha`` keys present in every GPU tier; all branches now share one
    schema (callers using ``.get()`` with the same defaults are unaffected).
    """
    if not torch.cuda.is_available():
        return {
            "batch_size": 1,
            "max_length": 512,
            "lora_r": 8,
            "lora_alpha": 16,
            "info": "โ ๏ธ CPU only - Use minimal settings to avoid memory issues"
        }

    # Sized from device 0 only — multi-GPU setups are treated like the first card.
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)

    # (min VRAM GB, batch_size, max_length, lora_r, lora_alpha, info template),
    # ordered largest-first so the first match wins.
    tiers = [
        (40, 4, 2048, 32, 64, "๐ High-end GPU ({mem:.0f}GB) - Recommended for large batches and long sequences"),
        (24, 2, 1536, 16, 32, "๐ช High-capacity GPU ({mem:.0f}GB) - Good for moderate sequences"),
        (16, 2, 1024, 16, 32, "โ Mid-range GPU ({mem:.0f}GB) - Suitable for standard training"),
        (8, 1, 768, 8, 16, "โก Entry-level GPU ({mem:.0f}GB) - Use smaller sequences"),
        (0, 1, 512, 8, 16, "โ ๏ธ Low VRAM GPU ({mem:.0f}GB) - Use minimal settings"),
    ]
    for min_gb, batch, max_len, rank, alpha, template in tiers:
        if gpu_memory_gb >= min_gb:
            return {
                "batch_size": batch,
                "max_length": max_len,
                "lora_r": rank,
                "lora_alpha": alpha,
                "info": template.format(mem=gpu_memory_gb),
            }
|
|
|
|
|
def list_datasets():
    """Return paths of JSON/JSONL datasets, or a placeholder entry if none."""
    found = []
    for pattern in ("*.jsonl", "*.json"):
        # Bundled datasets, excluding anything with "claude" in the path.
        found += [str(p) for p in DATASET_DIR.rglob(pattern) if "claude" not in str(p)]
        # User-uploaded / processed datasets.
        found += [str(p) for p in UPLOADS_DIR.rglob(pattern)]
    return found or ["No datasets found"]
|
|
|
|
|
def list_models():
    """Collect candidate model directories from the workspace.

    Any directory whose name contains "mistral" counts, gathered from the app
    directory, its parent, and MODELS_DIR; deduplicated and sorted.
    """
    found = set()

    # Fine-tuned outputs saved next to this script (hidden dirs skipped).
    for entry in BASE_DIR.iterdir():
        if entry.is_dir() and "mistral" in entry.name.lower() and not entry.name.startswith('.'):
            found.add(str(entry))

    # Sibling checkouts one level up (e.g. downloaded base models).
    for entry in BASE_DIR.parent.iterdir():
        if entry.is_dir() and "mistral" in entry.name.lower():
            found.add(str(entry))

    # Models shipped inside the serving project tree.
    if MODELS_DIR.exists():
        for entry in MODELS_DIR.iterdir():
            if entry.is_dir() and "mistral" in entry.name.lower():
                found.add(str(entry))

    return sorted(found) if found else ["No models found"]
|
|
|
|
|
def list_base_models():
    """Return base-model candidates for fine-tuning, most-preferred first."""
    local_base = "/workspace/ftt/base_models/Mistral-7B-v0.1"
    candidates = []

    # Prefer the locally mirrored base checkpoint when present.
    if Path(local_base).exists():
        candidates.append(local_base)

    # Any previously fine-tuned model can serve as a continuation base.
    candidates.extend(list_models())

    # Always offer the stock HuggingFace checkpoints as a fallback.
    candidates.extend([
        "mistralai/Mistral-7B-v0.1",
        "mistralai/Mistral-7B-Instruct-v0.2",
    ])

    return candidates if candidates else [local_base]
|
|
|
|
|
def check_api_status():
    """Probe the local API server's /health endpoint.

    Returns:
        (bool, str): running flag plus a human-readable status message.
    """
    try:
        response = requests.get("http://localhost:8000/health", timeout=2)
        if response.status_code != 200:
            return False, "โ API returned error"
        data = response.json()
        return True, (
            f"โ API is running\n"
            f"๐ฏ Model: {data.get('model_path', 'Unknown')}\n"
            f"๐ป Device: {data.get('device', 'Unknown')}"
        )
    except requests.exceptions.ConnectionError:
        return False, "โ API is not running"
    except Exception as e:
        return False, f"โ Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
def process_uploaded_file(file):
    """Copy an uploaded dataset file into UPLOADS_DIR under a unique name.

    Args:
        file: Gradio file object exposing a ``.name`` path attribute, or None.

    Returns:
        (saved_path | None, status_message)
    """
    if file is None:
        return None, "โ ๏ธ No file uploaded"

    try:
        filename = Path(file.name).name
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        # Bug fix: the original interpolated a literal placeholder string
        # instead of the computed (and unused) `filename`, so every upload
        # lost its original name and extension. Timestamp prefix avoids
        # collisions between uploads of the same file.
        new_filename = f"{timestamp}_{filename}"
        save_path = UPLOADS_DIR / new_filename

        shutil.copy(file.name, save_path)

        return str(save_path), f"โ File uploaded successfully: {new_filename}"
    except Exception as e:
        return None, f"โ Error uploading file: {str(e)}"
|
|
|
|
|
def load_huggingface_dataset(dataset_name, split_ratio):
    """Download a HuggingFace dataset and split it into train/val/test JSONL files.

    Args:
        dataset_name: Hub dataset id, e.g. "timdettmers/openassistant-guanaco".
        split_ratio: Percentage (0-100) of samples used for training; the
            remainder is divided roughly equally between validation and test.

    Returns:
        (train_path | None, status_message)
    """
    try:
        # Imported lazily so the UI can start even if `datasets` is missing.
        from datasets import load_dataset

        dataset = load_dataset(dataset_name)

        # Prefer the canonical "train" split; otherwise fall back to whatever
        # split the dataset ships first.
        if "train" in dataset:
            data = dataset["train"]
        else:
            split_name = list(dataset.keys())[0]
            data = dataset[split_name]

        # Contiguous head/middle/tail split (no shuffling); test absorbs any
        # rounding slack from the integer division.
        total_size = len(data)
        train_size = int(total_size * split_ratio / 100)
        val_size = int(total_size * (100 - split_ratio) / 200)
        test_size = total_size - train_size - val_size

        train_data = data.select(range(train_size))
        val_data = data.select(range(train_size, train_size + val_size))
        test_data = data.select(range(train_size + val_size, total_size))

        # Write the three splits under a unique timestamped directory.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_dir = UPLOADS_DIR / f"hf_{dataset_name.replace('/', '_')}_{timestamp}"
        output_dir.mkdir(exist_ok=True)

        train_path = output_dir / "train.jsonl"
        val_path = output_dir / "val.jsonl"
        test_path = output_dir / "test.jsonl"

        # NOTE(review): Dataset.to_json writes JSON-lines by default —
        # presumably matching the .jsonl extension; confirm against the
        # installed `datasets` version.
        train_data.to_json(train_path)
        val_data.to_json(val_path)
        test_data.to_json(test_path)

        info = f"โ Dataset loaded and split successfully!\n"
        info += f"๐ Total samples: {total_size}\n"
        info += f" โข Train: {train_size} samples\n"
        info += f" โข Validation: {val_size} samples\n"
        info += f" โข Test: {test_size} samples\n"
        info += f"๐ Saved to: {output_dir}"

        return str(train_path), info

    except Exception as e:
        return None, f"โ Error loading HuggingFace dataset: {str(e)}"
|
|
|
|
|
def split_local_dataset(dataset_path, split_ratio):
    """Shuffle-split a local JSON/JSONL dataset into train/val/test JSONL files.

    Args:
        dataset_path: Path to a .json or .jsonl file.
        split_ratio: Percentage (0-100) of samples used for training; the
            remainder is divided equally between validation and test.

    Returns:
        (train_path | None, status_message)
    """
    try:
        # Imported lazily: pandas/sklearn are only needed when splitting.
        import pandas as pd
        from sklearn.model_selection import train_test_split

        if dataset_path.endswith('.jsonl'):
            data = pd.read_json(dataset_path, lines=True)
        else:
            data = pd.read_json(dataset_path)

        total_size = len(data)

        train_ratio = split_ratio / 100

        # Carve out the training set first, then split the remainder 50/50
        # into validation and test. Fixed seed keeps splits reproducible.
        # (Removed an unused `val_test_ratio` local from the original.)
        train_data, temp_data = train_test_split(data, train_size=train_ratio, random_state=42)
        val_data, test_data = train_test_split(temp_data, train_size=0.5, random_state=42)

        # Write the three splits under a unique timestamped directory.
        dataset_name = Path(dataset_path).stem
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_dir = UPLOADS_DIR / f"{dataset_name}_split_{timestamp}"
        output_dir.mkdir(exist_ok=True)

        train_path = output_dir / "train.jsonl"
        val_path = output_dir / "val.jsonl"
        test_path = output_dir / "test.jsonl"

        train_data.to_json(train_path, orient='records', lines=True)
        val_data.to_json(val_path, orient='records', lines=True)
        test_data.to_json(test_path, orient='records', lines=True)

        info = f"โ Dataset split successfully!\n"
        info += f"๐ Total samples: {total_size}\n"
        info += f" โข Train: {len(train_data)} samples\n"
        info += f" โข Validation: {len(val_data)} samples\n"
        info += f" โข Test: {len(test_data)} samples\n"
        info += f"๐ Saved to: {output_dir}"

        return str(train_path), info

    except Exception as e:
        return None, f"โ Error splitting dataset: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
def start_training(
    base_model,
    dataset_path,
    output_dir,
    max_length,
    num_epochs,
    batch_size,
    learning_rate,
    lora_r,
    lora_alpha
):
    """Launch fine-tuning in a background subprocess and stream its output.

    Args mirror the UI controls. Only base model, dataset, output dir and
    max length are passed on the command line; ``num_epochs``/``batch_size``/
    ``learning_rate``/``lora_r``/``lora_alpha`` are recorded in
    ``training_config.json`` — NOTE(review): presumably
    ``finetune_mistral7b.py`` reads that file; confirm.

    Returns:
        (status, progress, log_text): always a 3-tuple so every code path
        matches the Gradio click handler's three outputs. (Bug fix: the
        original returned 2-tuples on the early-exit and error paths while
        the success path returned 3 values, breaking the output binding.)
    """
    global training_process, training_log

    # Only one training run at a time.
    if training_process is not None and training_process.poll() is None:
        return "โ ๏ธ Training is already running!", "Already running", "".join(training_log)

    if not dataset_path or not os.path.exists(dataset_path):
        return f"โ Dataset not found: {dataset_path}", "Not started", ""

    if not output_dir:
        output_dir = f"./mistral-finetuned-{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    os.makedirs(output_dir, exist_ok=True)

    # Reset the log buffer up front so the cache-clearing messages below are
    # retained. (Bug fix: the original appended them and then reassigned
    # training_log afterwards, silently discarding them.)
    training_log = [f"๐ Starting training...\n"]

    # Best-effort removal of a possibly corrupted cached checkpoint download.
    # shutil.rmtree replaces the original's equivalent find/rm subprocess
    # calls and duplicate local imports.
    cache_dir = Path("/workspace/.hf_home/hub/models--mistralai--Mistral-7B-v0.1")
    training_log.append("๐งน Clearing HuggingFace cache...\n")
    try:
        if cache_dir.exists():
            shutil.rmtree(cache_dir, ignore_errors=True)
        training_log.append("โ Cache cleared successfully\n")
    except Exception as e:
        training_log.append(f"โ ๏ธ Cache clear warning (non-critical): {e}\n")

    cmd = [
        sys.executable,
        "-u",  # unbuffered child stdout so the log tail updates live
        str(FT_DIR / "finetune_mistral7b.py"),
        "--base-model", base_model,
        "--dataset", dataset_path,
        "--output-dir", output_dir,
        "--max-length", str(max_length),
    ]

    # Persist the full run configuration alongside the model output.
    config = {
        "base_model": base_model,
        "dataset": dataset_path,
        "output_dir": output_dir,
        "max_length": max_length,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "started_at": datetime.now().isoformat()
    }
    config_path = os.path.join(output_dir, "training_config.json")
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)

    training_log.append(f"๐ Configuration saved to: {config_path}\n")
    training_log.append(f"๐พ Output directory: {output_dir}\n")
    training_log.append(f"๐ Dataset: {dataset_path}\n")
    training_log.append(f"๐ค Base model: {base_model}\n")
    training_log.append(f"\n{'='*70}\n")
    training_log.append(f"Training Command:\n{' '.join(cmd)}\n")
    training_log.append(f"{'='*70}\n\n")

    try:
        env = os.environ.copy()
        env['PYTHONUNBUFFERED'] = '1'

        training_process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into the tailed stream
            universal_newlines=True,
            bufsize=1,  # line-buffered
            env=env
        )

        # Tail the child's stdout on a daemon thread, keeping a bounded buffer.
        def monitor_training():
            global training_log
            for line in training_process.stdout:
                training_log.append(line)
                if len(training_log) > 1000:
                    training_log = training_log[-1000:]

        thread = threading.Thread(target=monitor_training, daemon=True)
        thread.start()

        return f"โ Training started!\n๐ Output: {output_dir}", "Initializing training...", "".join(training_log)

    except Exception as e:
        return f"โ Error starting training: {str(e)}", "Failed to start", "".join(training_log)
|
|
|
|
|
def stop_training():
    """Terminate the running training subprocess, escalating to SIGKILL.

    Returns:
        (status, progress, log_text)
    """
    global training_process, training_log

    running = training_process is not None and training_process.poll() is None
    if not running:
        return "โ ๏ธ No training process is running", "Stopped", "".join(training_log)

    try:
        # Ask politely first; give the process 10 seconds to exit on its own.
        training_process.terminate()
        training_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        training_process.kill()
        training_log.append("\n\nโ ๏ธ Training force-killed\n")
        return "โ ๏ธ Training force-killed (did not terminate gracefully)", "Force stopped", "".join(training_log)
    except Exception as e:
        return f"โ Error stopping training: {str(e)}", "Error", "".join(training_log)

    training_log.append("\n\n๐ Training stopped by user\n")
    return "โ Training stopped", "Stopped by user", "".join(training_log)
|
|
|
|
|
def get_training_status():
    """Summarise the state of the training subprocess.

    Returns:
        (status, progress, log_text)
    """
    global training_process, training_log

    log_text = "".join(training_log)

    if training_process is None:
        return "โช Not started", "Ready to start", log_text

    exit_code = training_process.poll()
    if exit_code is None:
        # Still running: surface the most recent epoch-related log line.
        progress = "Initializing..."
        if "epoch" in log_text.lower():
            progress = "Training in progress..."
            for line in reversed(log_text.split('\n')):
                if 'epoch' in line.lower():
                    progress = f"Training... {line.strip()}"
                    break
        return "๐ข Running", progress, log_text

    if exit_code == 0:
        return "โ Completed successfully", "Training complete! Check output directory.", log_text

    return f"โ Failed (exit code: {exit_code})", "Training failed. Check logs for errors.", log_text
|
|
|
|
|
def refresh_training_log():
    """Return the accumulated training log as a single string."""
    return "".join(training_log)
|
|
|
|
|
|
|
|
|
|
|
def start_api_server(model_path, host, port):
    """Launch the inference API server as a background subprocess.

    Args:
        model_path: Local model directory, or a HuggingFace model id if the
            path does not exist on disk.
        host: Interface to bind, e.g. "0.0.0.0".
        port: TCP port for the server.

    Returns:
        (status_message, log_text)
    """
    global api_process, api_log

    # Only one server instance at a time.
    if api_process is not None and api_process.poll() is None:
        return "โ ๏ธ API server is already running!", "".join(api_log)

    # A path that does not exist locally is assumed to be a HuggingFace id.
    if not os.path.exists(model_path):
        api_log = [f"๐ Starting API server with HuggingFace model...\n"]
        api_log.append(f"๐ค HuggingFace Model: {model_path}\n")
    else:
        api_log = [f"๐ Starting API server with local model...\n"]
        api_log.append(f"๐พ Local Model: {model_path}\n")

    cmd = [
        sys.executable,
        str(API_DIR / "api_server.py"),
        "--model-path", model_path,
        "--host", host,
        "--port", str(port),
    ]

    api_log.append(f"๐ Host: {host}\n")
    api_log.append(f"๐ Port: {port}\n")
    api_log.append(f"\n{'='*70}\n")
    api_log.append(f"Server Command:\n{' '.join(cmd)}\n")
    api_log.append(f"{'='*70}\n\n")

    try:
        api_process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into the tailed stream
            universal_newlines=True,
            bufsize=1  # line-buffered
        )

        # Tail server output on a daemon thread; keep a bounded buffer.
        def monitor_api():
            global api_log
            for line in api_process.stdout:
                api_log.append(line)
                if len(api_log) > 500:
                    api_log = api_log[-500:]

        thread = threading.Thread(target=monitor_api, daemon=True)
        thread.start()

        # Give the server a moment to bind before probing /health.
        # NOTE(review): 3s is often too short for model loading; the UI copes
        # by reporting "not responding yet" below — confirm acceptable UX.
        time.sleep(3)

        is_running, status_msg = check_api_status()
        if is_running:
            return f"โ API server started!\n{status_msg}\n\n๐ก Access at: http://{host}:{port}\n๐ Docs at: http://{host}:{port}/docs", "".join(api_log)
        else:
            return f"โ ๏ธ API server started but not responding yet. Check logs.", "".join(api_log)

    except Exception as e:
        return f"โ Error starting API server: {str(e)}", "".join(api_log)
|
|
|
|
|
def stop_api_server():
    """Terminate the API server subprocess, escalating to SIGKILL.

    Returns:
        (status_message, log_text)
    """
    global api_process, api_log

    running = api_process is not None and api_process.poll() is None
    if not running:
        return "โ ๏ธ No API server is running", "".join(api_log)

    try:
        # Graceful shutdown first, with a 10-second grace period.
        api_process.terminate()
        api_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        api_process.kill()
        api_log.append("\n\nโ ๏ธ API server force-killed\n")
        return "โ ๏ธ API server force-killed (did not terminate gracefully)", "".join(api_log)
    except Exception as e:
        return f"โ Error stopping API server: {str(e)}", "".join(api_log)

    api_log.append("\n\n๐ API server stopped by user\n")
    return "โ API server stopped", "".join(api_log)
|
|
|
|
|
def get_api_status():
    """Return the live API health message plus the accumulated server log."""
    _, status_msg = check_api_status()
    return status_msg, "".join(api_log)
|
|
|
|
|
def refresh_api_log():
    """Return the accumulated API server log as a single string."""
    return "".join(api_log)
|
|
|
|
|
|
|
|
|
|
|
def test_inference(model_path, prompt, max_length, temperature):
    """Generate text from the model.

    Uses the running API server when available, otherwise loads the model
    directly in-process. Returns a display string (result or error message);
    never raises.
    """
    try:
        api_up, _ = check_api_status()

        if not api_up:
            # Fall back to direct, in-process inference.
            from inference.inference_mistral7b import load_local_model, generate_with_local_model

            model, tokenizer = load_local_model(model_path)
            text = generate_with_local_model(
                model, tokenizer, prompt,
                max_length=int(max_length),
                temperature=float(temperature)
            )
            return f"โ Response via Direct Inference:\n\n{text}"

        payload = {
            "prompt": prompt,
            "max_length": int(max_length),
            "temperature": float(temperature)
        }
        response = requests.post(
            "http://localhost:8000/api/generate",
            json=payload,
            timeout=120
        )
        response.raise_for_status()
        result = response.json()
        return f"โ Response via API:\n\n{result['response']}"

    except Exception as e:
        return f"โ Error during inference: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
def create_interface(): |
|
|
"""Create the Gradio interface""" |
|
|
|
|
|
|
|
|
gpu_rec = get_gpu_recommendations() |
|
|
|
|
|
with gr.Blocks(title="Mistral Fine-Tuning & Hosting Interface") as app: |
|
|
gr.Markdown("# ๐ Mistral Model Fine-Tuning & Hosting Interface") |
|
|
gr.Markdown("Complete interface for training and deploying Mistral models") |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=3): |
|
|
device_info = get_device_info() |
|
|
gr.Markdown(f"### ๐ป System Information\n{device_info}\n\n{gpu_rec['info']}") |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
gr.Markdown("### โ๏ธ System Controls") |
|
|
|
|
|
def kill_gradio_server(): |
|
|
"""Kill the Gradio server process""" |
|
|
import os |
|
|
import signal |
|
|
pid = os.getpid() |
|
|
|
|
|
def delayed_kill(): |
|
|
time.sleep(1) |
|
|
os.kill(pid, signal.SIGTERM) |
|
|
threading.Thread(target=delayed_kill, daemon=True).start() |
|
|
return "๐ Shutting down Gradio server in 1 second...", api_server_status.value |
|
|
|
|
|
def stop_api_control(): |
|
|
"""Stop API server from control panel""" |
|
|
status, _ = stop_api_server() |
|
|
return server_status.value, status |
|
|
|
|
|
server_status = gr.Textbox( |
|
|
label="Gradio Server Status", |
|
|
value="๐ข Running", |
|
|
interactive=False, |
|
|
lines=1 |
|
|
) |
|
|
|
|
|
api_server_status = gr.Textbox( |
|
|
label="API Server Status", |
|
|
value="โช Not started", |
|
|
interactive=False, |
|
|
lines=1 |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
kill_server_btn = gr.Button("๐ Shutdown Gradio", variant="stop", scale=1) |
|
|
stop_api_btn_control = gr.Button("โน๏ธ Stop API Server", variant="secondary", scale=1) |
|
|
|
|
|
kill_server_btn.click( |
|
|
fn=kill_gradio_server, |
|
|
outputs=[server_status, api_server_status] |
|
|
) |
|
|
|
|
|
stop_api_btn_control.click( |
|
|
fn=stop_api_control, |
|
|
outputs=[server_status, api_server_status] |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Tabs() as tabs: |
|
|
|
|
|
|
|
|
with gr.Tab("๐ Fine-Tuning"): |
|
|
gr.Markdown("### Configure and start model fine-tuning") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
gr.Markdown("#### Training Configuration") |
|
|
|
|
|
base_model_input = gr.Dropdown( |
|
|
label="Base Model (Select existing model or HuggingFace ID)", |
|
|
choices=list_base_models(), |
|
|
value=list_base_models()[0] if list_base_models() else "/workspace/ftt/base_models/Mistral-7B-v0.1", |
|
|
allow_custom_value=True, |
|
|
info="๐ก Select a base model to start from, or a fine-tuned model to continue training" |
|
|
) |
|
|
|
|
|
gr.Markdown("#### Dataset Selection") |
|
|
|
|
|
dataset_source = gr.Radio( |
|
|
choices=["Local File", "Upload File", "HuggingFace Dataset"], |
|
|
value="Local File", |
|
|
label="Dataset Source" |
|
|
) |
|
|
|
|
|
|
|
|
dataset_input = gr.Dropdown( |
|
|
label="Select Local Dataset", |
|
|
choices=list_datasets(), |
|
|
value=list_datasets()[0] if list_datasets()[0] != "No datasets found" else None, |
|
|
allow_custom_value=True, |
|
|
visible=True |
|
|
) |
|
|
|
|
|
|
|
|
dataset_upload = gr.File( |
|
|
label="Upload Dataset File (JSON/JSONL)", |
|
|
file_types=[".json", ".jsonl"], |
|
|
visible=False |
|
|
) |
|
|
|
|
|
|
|
|
hf_dataset_input = gr.Textbox( |
|
|
label="HuggingFace Dataset Name", |
|
|
placeholder="e.g., timdettmers/openassistant-guanaco", |
|
|
visible=False |
|
|
) |
|
|
|
|
|
|
|
|
gr.Markdown("#### Dataset Processing") |
|
|
split_dataset = gr.Checkbox( |
|
|
label="Split dataset into train/val/test", |
|
|
value=False |
|
|
) |
|
|
|
|
|
split_ratio = gr.Slider( |
|
|
label="Training Split % (remaining split equally between val/test)", |
|
|
minimum=60, |
|
|
maximum=90, |
|
|
value=80, |
|
|
step=5 |
|
|
) |
|
|
|
|
|
process_dataset_btn = gr.Button("๐ Process Dataset") |
|
|
dataset_status = gr.Textbox(label="Dataset Status", interactive=False, lines=6) |
|
|
|
|
|
output_dir_input = gr.Textbox( |
|
|
label="Output Directory", |
|
|
value=f"./mistral-finetuned-{datetime.now().strftime('%Y%m%d_%H%M%S')}", |
|
|
placeholder="Where to save the fine-tuned model" |
|
|
) |
|
|
|
|
|
gr.Markdown("#### Training Parameters") |
|
|
gr.Markdown(f"*๐ก GPU-Optimized Defaults: Batch={gpu_rec['batch_size']}, Max Length={gpu_rec['max_length']}, LoRA Rank={gpu_rec.get('lora_r', 16)}*") |
|
|
|
|
|
gr.Markdown("---") |
|
|
gr.Markdown("**Sequence & Training Settings**") |
|
|
|
|
|
with gr.Row(): |
|
|
max_length_input = gr.Slider( |
|
|
label="Max Sequence Length", |
|
|
info="๐ Tokens per example | Higher=more context but more memory | Standard: 512-2048 | Your GPU: " + str(gpu_rec['max_length']), |
|
|
minimum=128, |
|
|
maximum=6000, |
|
|
value=gpu_rec['max_length'], |
|
|
step=128 |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
num_epochs_input = gr.Slider( |
|
|
label="Number of Epochs", |
|
|
info="๐ Training passes | More=better learning but risk overfitting | Standard: 3-5 | Quick test: 1", |
|
|
minimum=1, |
|
|
maximum=10, |
|
|
value=3, |
|
|
step=1 |
|
|
) |
|
|
|
|
|
batch_size_input = gr.Slider( |
|
|
label="Batch Size", |
|
|
info="๐ฆ Samples together | Larger=faster but more memory | Your GPU: " + str(gpu_rec['batch_size']) + " | Low VRAM: 1", |
|
|
minimum=1, |
|
|
maximum=16, |
|
|
value=gpu_rec['batch_size'], |
|
|
step=1 |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
learning_rate_input = gr.Number( |
|
|
label="Learning Rate", |
|
|
info="โก Training speed | Typical: 1e-5 to 5e-4 | Lower=stable | Higher=fast | Default: 5e-5", |
|
|
value=5e-5, |
|
|
precision=6 |
|
|
) |
|
|
|
|
|
gr.Markdown("---") |
|
|
gr.Markdown("**LoRA Configuration** *(Efficient fine-tuning by training small parameter subset)*") |
|
|
|
|
|
with gr.Row(): |
|
|
lora_r_input = gr.Slider( |
|
|
label="LoRA Rank (r)", |
|
|
info="๐ฏ Adaptation matrix rank | Higher=more capacity/slower | Standard: 8-32 | Your GPU: " + str(gpu_rec.get('lora_r', 16)), |
|
|
minimum=4, |
|
|
maximum=64, |
|
|
value=gpu_rec.get('lora_r', 16), |
|
|
step=4 |
|
|
) |
|
|
|
|
|
lora_alpha_input = gr.Slider( |
|
|
label="LoRA Alpha", |
|
|
info="โ๏ธ Scaling factor | Typically 2ร rank | Controls adaptation strength | Recommended: " + str(gpu_rec.get('lora_alpha', 32)), |
|
|
minimum=8, |
|
|
maximum=128, |
|
|
value=gpu_rec.get('lora_alpha', 32), |
|
|
step=8 |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
start_train_btn = gr.Button("โถ๏ธ Start Training", variant="primary") |
|
|
stop_train_btn = gr.Button("โน๏ธ Stop Training", variant="stop") |
|
|
refresh_train_btn = gr.Button("๐ Refresh Status") |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
gr.Markdown("#### Training Status & Logs") |
|
|
|
|
|
training_status_output = gr.Textbox( |
|
|
label="Status (Right-click to copy)", |
|
|
value="โช Not started", |
|
|
interactive=False, |
|
|
lines=2, |
|
|
max_lines=3 |
|
|
) |
|
|
|
|
|
training_progress = gr.Textbox( |
|
|
label="Progress - Epoch/Loss Info (Right-click to copy)", |
|
|
value="Ready to start", |
|
|
interactive=False, |
|
|
lines=2, |
|
|
max_lines=3 |
|
|
) |
|
|
|
|
|
training_log_output = gr.Textbox( |
|
|
label="Training Logs - Scrollable (Click Refresh to update, Right-click to copy)", |
|
|
lines=22, |
|
|
max_lines=22, |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
|
|
|
def update_dataset_visibility(source): |
|
|
return ( |
|
|
gr.update(visible=(source == "Local File")), |
|
|
gr.update(visible=(source == "Upload File")), |
|
|
gr.update(visible=(source == "HuggingFace Dataset")) |
|
|
) |
|
|
|
|
|
dataset_source.change( |
|
|
fn=update_dataset_visibility, |
|
|
inputs=[dataset_source], |
|
|
outputs=[dataset_input, dataset_upload, hf_dataset_input] |
|
|
) |
|
|
|
|
|
|
|
|
def process_dataset(source, local_path, uploaded_file, hf_name, should_split, ratio): |
|
|
if source == "Upload File": |
|
|
if uploaded_file is None: |
|
|
return None, "โ ๏ธ Please upload a file" |
|
|
path, msg = process_uploaded_file(uploaded_file) |
|
|
if path and should_split: |
|
|
path, msg = split_local_dataset(path, ratio) |
|
|
return path, msg |
|
|
elif source == "HuggingFace Dataset": |
|
|
if not hf_name: |
|
|
return None, "โ ๏ธ Please enter a HuggingFace dataset name" |
|
|
return load_huggingface_dataset(hf_name, ratio) |
|
|
else: |
|
|
if not local_path or local_path == "No datasets found": |
|
|
return None, "โ ๏ธ Please select a dataset" |
|
|
if should_split: |
|
|
return split_local_dataset(local_path, ratio) |
|
|
return local_path, f"โ
Using existing dataset: {local_path}" |
|
|
|
|
|
process_dataset_btn.click( |
|
|
fn=process_dataset, |
|
|
inputs=[dataset_source, dataset_input, dataset_upload, hf_dataset_input, split_dataset, split_ratio], |
|
|
outputs=[dataset_input, dataset_status] |
|
|
) |
|
|
|
|
|
|
|
|
start_train_btn.click( |
|
|
fn=start_training, |
|
|
inputs=[ |
|
|
base_model_input, dataset_input, output_dir_input, |
|
|
max_length_input, num_epochs_input, batch_size_input, |
|
|
learning_rate_input, lora_r_input, lora_alpha_input |
|
|
], |
|
|
outputs=[training_status_output, training_progress, training_log_output] |
|
|
) |
|
|
|
|
|
stop_train_btn.click( |
|
|
fn=stop_training, |
|
|
outputs=[training_status_output, training_progress, training_log_output] |
|
|
) |
|
|
|
|
|
refresh_train_btn.click( |
|
|
fn=get_training_status, |
|
|
outputs=[training_status_output, training_progress, training_log_output] |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Tab("๐ API Hosting"): |
|
|
gr.Markdown("### Start and manage API server for model inference") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
gr.Markdown("#### Server Configuration") |
|
|
|
|
|
api_model_source = gr.Radio( |
|
|
choices=["Local Model", "HuggingFace Model"], |
|
|
value="Local Model", |
|
|
label="Model Source" |
|
|
) |
|
|
|
|
|
api_model_input = gr.Dropdown( |
|
|
label="Select Local Model", |
|
|
choices=list_models(), |
|
|
value=list_models()[0] if list_models()[0] != "No models found" else None, |
|
|
allow_custom_value=True |
|
|
) |
|
|
|
|
|
api_hf_model_input = gr.Textbox( |
|
|
label="HuggingFace Model ID", |
|
|
placeholder="e.g., mistralai/Mistral-7B-v0.1 or your-username/your-model", |
|
|
visible=False |
|
|
) |
|
|
|
|
|
api_host_input = gr.Textbox( |
|
|
label="Host", |
|
|
value="0.0.0.0", |
|
|
placeholder="0.0.0.0 for all interfaces" |
|
|
) |
|
|
|
|
|
api_port_input = gr.Number( |
|
|
label="Port", |
|
|
value=8000, |
|
|
precision=0 |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
start_api_btn = gr.Button("โถ๏ธ Start Server", variant="primary") |
|
|
stop_api_btn = gr.Button("โน๏ธ Stop Server", variant="stop") |
|
|
refresh_api_btn = gr.Button("๐ Refresh Status") |
|
|
|
|
|
api_status_output = gr.Textbox( |
|
|
label="Server Status", |
|
|
value="โช Not started", |
|
|
interactive=False, |
|
|
lines=5 |
|
|
) |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
gr.Markdown("#### Server Logs") |
|
|
|
|
|
api_log_output = gr.Textbox( |
|
|
label="API Server Logs", |
|
|
lines=35, |
|
|
max_lines=35, |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
|
|
|
def update_api_model_visibility(source): |
|
|
return ( |
|
|
gr.update(visible=(source == "Local Model")), |
|
|
gr.update(visible=(source == "HuggingFace Model")) |
|
|
) |
|
|
|
|
|
api_model_source.change( |
|
|
fn=update_api_model_visibility, |
|
|
inputs=[api_model_source], |
|
|
outputs=[api_model_input, api_hf_model_input] |
|
|
) |
|
|
|
|
|
|
|
|
def start_api_wrapper(source, local_model, hf_model, host, port): |
|
|
model_path = hf_model if source == "HuggingFace Model" else local_model |
|
|
if not model_path: |
|
|
return "โ ๏ธ Please select or enter a model", "" |
|
|
return start_api_server(model_path, host, port) |
|
|
|
|
|
start_api_btn.click( |
|
|
fn=start_api_wrapper, |
|
|
inputs=[api_model_source, api_model_input, api_hf_model_input, api_host_input, api_port_input], |
|
|
outputs=[api_status_output, api_log_output] |
|
|
) |
|
|
|
|
|
stop_api_btn.click( |
|
|
fn=stop_api_server, |
|
|
outputs=[api_status_output, api_log_output] |
|
|
) |
|
|
|
|
|
refresh_api_btn.click( |
|
|
fn=get_api_status, |
|
|
outputs=[api_status_output, api_log_output] |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Tab("๐งช Test Inference"): |
|
|
gr.Markdown("### Test your fine-tuned models") |
|
|
gr.Markdown("๐ก The interface will use the API if it's running, otherwise it will load the model directly") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
inference_model_source = gr.Radio( |
|
|
choices=["Local Model", "HuggingFace Model"], |
|
|
value="Local Model", |
|
|
label="Model Source" |
|
|
) |
|
|
|
|
|
inference_model_input = gr.Dropdown( |
|
|
label="Select Local Model", |
|
|
choices=list_models(), |
|
|
value=list_models()[0] if list_models()[0] != "No models found" else None, |
|
|
allow_custom_value=True |
|
|
) |
|
|
|
|
|
inference_hf_model_input = gr.Textbox( |
|
|
label="HuggingFace Model ID", |
|
|
placeholder="e.g., mistralai/Mistral-7B-v0.1", |
|
|
visible=False |
|
|
) |
|
|
|
|
|
gr.Markdown("#### Prompt Configuration") |
|
|
|
|
|
inference_system_instruction = gr.Textbox( |
|
|
label="System Instruction (Pre-filled, editable)", |
|
|
lines=4, |
|
|
value="You are Elinnos RTL Code Generator v1.0, a specialized Verilog/SystemVerilog code generation agent. Your role: Generate clean, synthesizable RTL code for hardware design tasks. Output ONLY functional RTL code with no $display, assertions, comments, or debug statements.", |
|
|
info="๐ก This is pre-filled with your model's training format. Edit if needed." |
|
|
) |
|
|
|
|
|
inference_user_prompt = gr.Textbox( |
|
|
label="User Prompt (Your request)", |
|
|
lines=3, |
|
|
placeholder="Example: Generate a synchronous FIFO with 8-bit data width, depth 4, write_enable, read_enable, full flag, empty flag.", |
|
|
value="" |
|
|
) |
|
|
|
|
|
gr.Markdown("#### Generation Parameters") |
|
|
|
|
|
with gr.Row(): |
|
|
inference_max_length = gr.Slider( |
|
|
label="Max Length", |
|
|
info="Maximum tokens to generate. Higher = longer responses but slower", |
|
|
minimum=128, |
|
|
maximum=6000, |
|
|
value=512, |
|
|
step=128 |
|
|
) |
|
|
|
|
|
inference_temperature = gr.Slider( |
|
|
label="Temperature", |
|
|
info="Creativity control: 0.1=focused/deterministic, 1.0=creative/random", |
|
|
minimum=0.1, |
|
|
maximum=2.0, |
|
|
value=0.7, |
|
|
step=0.1 |
|
|
) |
|
|
|
|
|
inference_btn = gr.Button("๐ Generate", variant="primary") |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
inference_output = gr.Textbox( |
|
|
label="Generated Response", |
|
|
lines=30, |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
|
|
|
def update_inference_model_visibility(source):
    """Show exactly one model selector based on the chosen source.

    Returns a pair of gr.update objects: the first toggles the local-model
    dropdown, the second the HuggingFace-ID textbox. Exactly one is visible
    at a time.
    """
    use_local = source == "Local Model"
    use_hf = source == "HuggingFace Model"
    local_update = gr.update(visible=use_local)
    hf_update = gr.update(visible=use_hf)
    return (local_update, hf_update)
|
|
|
|
|
inference_model_source.change( |
|
|
fn=update_inference_model_visibility, |
|
|
inputs=[inference_model_source], |
|
|
outputs=[inference_model_input, inference_hf_model_input] |
|
|
) |
|
|
|
|
|
|
|
|
def test_inference_wrapper(source, local_model, hf_model, system_instruction, user_prompt, max_len, temp): |
|
|
model_path = hf_model if source == "HuggingFace Model" else local_model |
|
|
if not model_path: |
|
|
return "โ ๏ธ Please select or enter a model" |
|
|
|
|
|
|
|
|
full_prompt = f"{system_instruction}\n\nUser:\n{user_prompt}" |
|
|
|
|
|
return test_inference(model_path, full_prompt, max_len, temp) |
|
|
|
|
|
inference_btn.click( |
|
|
fn=test_inference_wrapper, |
|
|
inputs=[inference_model_source, inference_model_input, inference_hf_model_input, |
|
|
inference_system_instruction, inference_user_prompt, inference_max_length, inference_temperature], |
|
|
outputs=inference_output |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Tab("๐ Documentation"): |
|
|
gr.Markdown(""" |
|
|
## ๐ User Guide |
|
|
|
|
|
### ๐ Fine-Tuning |
|
|
|
|
|
#### Dataset Options |
|
|
|
|
|
**1. Local File**: Select from existing datasets in the workspace |
|
|
- Use datasets already present in the `dataset/` directory |
|
|
|
|
|
**2. Upload File**: Upload your own dataset file |
|
|
- Supported formats: JSON, JSONL |
|
|
- Files are saved to `uploads/` directory |
|
|
|
|
|
**3. HuggingFace Dataset**: Load from HuggingFace Hub |
|
|
- Enter dataset name (e.g., `timdettmers/openassistant-guanaco`) |
|
|
- Automatically downloaded and processed |
|
|
|
|
|
#### Dataset Processing |
|
|
|
|
|
- **Split Dataset**: Automatically split into train/validation/test sets |
|
|
- **Split Ratio**: Control train percentage (default 80%) |
|
|
- Remaining data split equally between validation and test |
|
|
|
|
|
#### Training Parameters Explained |
|
|
|
|
|
**Max Sequence Length** |
|
|
- Number of tokens (words/subwords) per training example |
|
|
- Higher = more context but requires more GPU memory |
|
|
- Standard: 512-2048, Maximum: 6000 (for long documents) |
|
|
- **Recommendation**: Start with GPU-recommended value |
|
|
|
|
|
**Number of Epochs** |
|
|
- How many complete passes through your dataset |
|
|
- More epochs = better learning but risk overfitting |
|
|
- Standard: 3-5 epochs |
|
|
- Watch training loss to avoid overfitting |
|
|
|
|
|
**Batch Size** |
|
|
- Number of examples processed simultaneously |
|
|
- Larger = faster training but more memory |
|
|
- Limited by your GPU memory |
|
|
- **GPU-based recommendations provided automatically** |
|
|
|
|
|
**Learning Rate** |
|
|
- Controls how quickly the model adapts |
|
|
- Too high = unstable training, too low = slow convergence |
|
|
- Standard: 1e-5 to 5e-4 |
|
|
- Default 5e-5 works well for most cases |
|
|
|
|
|
**LoRA Rank (r)** |
|
|
- Rank of low-rank adaptation matrices |
|
|
- Higher = more model capacity but slower training |
|
|
- Standard: 8-32 |
|
|
- Use lower values for smaller datasets |
|
|
|
|
|
**LoRA Alpha** |
|
|
- Scaling factor for LoRA updates |
|
|
- Typically set to 2ร the rank |
|
|
- Controls strength of fine-tuning adaptations |
|
|
|
|
|
### ๐ API Hosting |
|
|
|
|
|
#### Model Sources |
|
|
|
|
|
**Local Model**: Models saved on your machine |
|
|
- Fine-tuned models from training |
|
|
- Downloaded HuggingFace models |
|
|
|
|
|
**HuggingFace Model**: Direct from HuggingFace Hub |
|
|
- Enter model ID (e.g., `mistralai/Mistral-7B-v0.1`) |
|
|
- No need to download first |
|
|
- Automatically cached after first use |
|
|
|
|
|
#### API Endpoints |
|
|
|
|
|
Once running, access these endpoints: |
|
|
- **Generate**: `POST http://localhost:8000/api/generate` |
|
|
- **Health**: `GET http://localhost:8000/health` |
|
|
- **Docs**: `http://localhost:8000/docs` (Interactive API docs) |
|
|
|
|
|
### ๐งช Testing Inference |
|
|
|
|
|
#### Model Selection |
|
|
|
|
|
- **Local Model**: Use models from your filesystem |
|
|
- **HuggingFace Model**: Test any model from HuggingFace Hub |
|
|
|
|
|
#### Generation Parameters |
|
|
|
|
|
**Max Length** |
|
|
- Maximum number of tokens to generate |
|
|
- Higher = longer responses but slower generation |
|
|
- Balance between quality and speed |
|
|
- Typical: 256-1024 for most tasks |
|
|
|
|
|
**Temperature** |
|
|
- Controls randomness in generation |
|
|
- **0.1-0.3**: Very focused, deterministic (good for factual tasks) |
|
|
- **0.5-0.7**: Balanced creativity (default, recommended) |
|
|
- **0.8-1.0**: Creative, diverse outputs |
|
|
- **1.0+**: Very random (experimental, often incoherent) |
|
|
|
|
|
### ๐ก Tips & Best Practices |
|
|
|
|
|
#### GPU Memory Management |
|
|
- **Out of Memory?** Reduce batch size or max sequence length |
|
|
- Monitor GPU usage with `nvidia-smi` |
|
|
- Use gradient checkpointing for very long sequences |
|
|
|
|
|
#### Training Tips |
|
|
- **Start Small**: Test with a small subset first |
|
|
- **Monitor Loss**: Should decrease steadily |
|
|
- **Early Stopping**: Stop if validation loss increases |
|
|
- **Save Checkpoints**: Training saves to output directory |
|
|
|
|
|
#### Dataset Quality |
|
|
- **Format Consistency**: Ensure all examples follow the same format |
|
|
- **Quality over Quantity**: 1000 good examples > 10000 poor ones |
|
|
- **Diverse Examples**: Cover different aspects of your task |
|
|
|
|
|
#### Model Selection |
|
|
- **Base Model**: Start with Mistral-7B-v0.1 (good balance) |
|
|
- **Fine-tuned Models**: Use domain-specific if available |
|
|
- **Test First**: Always test inference before production |
|
|
|
|
|
### ๐ง Dataset Format |
|
|
|
|
|
Your training data should be in JSONL format (one JSON object per line): |
|
|
|
|
|
**Format 1: Instruction-Response** |
|
|
```json |
|
|
{"instruction": "Your question or task", "response": "Expected answer"} |
|
|
``` |
|
|
|
|
|
**Format 2: Prompt-Completion** |
|
|
```json |
|
|
{"prompt": "Your question", "completion": "Expected answer"} |
|
|
``` |
|
|
|
|
|
**Format 3: Chat Format** |
|
|
```json |
|
|
{"messages": [ |
|
|
{"role": "user", "content": "Question"}, |
|
|
{"role": "assistant", "content": "Answer"} |
|
|
]} |
|
|
``` |
|
|
|
|
|
### ๐จ Troubleshooting |
|
|
|
|
|
**Training Issues** |
|
|
- **Out of Memory**: Reduce batch size, max sequence length, or LoRA rank |
|
|
- **Slow Training**: Check GPU utilization, ensure CUDA is available |
|
|
- **NaN Loss**: Reduce learning rate or check data quality |
|
|
- **No Improvement**: Increase epochs, learning rate, or dataset size |
|
|
|
|
|
**API Issues** |
|
|
- **Server Won't Start**: Check if port is already in use |
|
|
- **Connection Refused**: Ensure firewall allows the port |
|
|
- **Slow Inference**: Model loading can take time on first request |
|
|
- **Out of Memory**: Model too large for GPU, use smaller model or CPU |
|
|
|
|
|
**Model Issues** |
|
|
- **Model Not Found**: Verify path or HuggingFace model ID |
|
|
- **Poor Quality**: May need more training data or epochs |
|
|
- **Inconsistent Output**: Adjust temperature or use lower value |
|
|
|
|
|
### ๐ Performance Benchmarks |
|
|
|
|
|
**GPU Memory Requirements (Mistral-7B)** |
|
|
- Training (LoRA): ~12-16GB VRAM |
|
|
- Inference: ~8-10GB VRAM |
|
|
- Batch size 1: minimum required |
|
|
- Batch size 4: optimal on 40GB GPU |
|
|
|
|
|
**Training Speed (A100 40GB)** |
|
|
- ~5000 tokens/second |
|
|
- 10k examples: ~30-60 minutes |
|
|
- Depends on sequence length and batch size |
|
|
|
|
|
### ๐ Recent Updates |
|
|
|
|
|
**v2.0 Features** |
|
|
- โ
File upload for datasets |
|
|
- โ
HuggingFace dataset integration |
|
|
- โ
Automatic dataset splitting (train/val/test) |
|
|
- โ
Extended max sequence length to 6000 tokens |
|
|
- โ
GPU-specific parameter recommendations |
|
|
- โ
HuggingFace model support for API hosting |
|
|
- โ
HuggingFace model support for inference |
|
|
- โ
Enhanced parameter tooltips and descriptions |
|
|
- โ
Public URL sharing enabled |
|
|
- โ
Improved documentation |
|
|
|
|
|
### ๐ Support |
|
|
|
|
|
For issues or questions: |
|
|
- Check logs for error messages |
|
|
- Verify GPU availability and memory |
|
|
- Ensure all dependencies are installed |
|
|
- Review dataset format and quality |
|
|
""") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return app |
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Launch the application.

    Prints a startup banner with device/dataset/model information, builds
    the Gradio interface, and blocks serving it on port 7860.
    """
    print("=" * 70)
    print("๐ Mistral Fine-Tuning & Hosting Interface v2.0")
    print("=" * 70)
    # Plain strings here: no placeholders, so the f-prefix was unnecessary.
    print("\n๐ป System Information:")
    print(get_device_info())
    gpu_rec = get_gpu_recommendations()
    print(f"\n{gpu_rec['info']}")
    print(f"\n๐ Base Directory: {BASE_DIR}")
    # NOTE(review): list_datasets() may use a placeholder entry when empty,
    # like list_models() does — confirm; the count below reports it as-is.
    print(f"๐ Available Datasets: {len(list_datasets())}")
    # Fix: list_models() returns the sentinel ["No models found"] when no
    # models exist, which previously made this line report "1" model.
    models = list_models()
    model_count = 0 if models == ["No models found"] else len(models)
    print(f"๐ค Available Models: {model_count}")
    print("\n" + "=" * 70)
    print("๐ Starting web interface...")
    print("=" * 70 + "\n")

    app = create_interface()
    app.launch(
        server_name="0.0.0.0",   # listen on all interfaces
        server_port=7860,
        share=True,              # NOTE: opens a public Gradio tunnel URL
        show_error=True
    )
|
|
|
|
|
# Script entry point: launch the web interface only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|
|