Prithvik-1 committed
Commit 2987c51 · verified · 1 Parent(s): d65b167

Upload interface_app.py with huggingface_hub

Files changed (1)
  1. interface_app.py +1330 -0
interface_app.py ADDED
@@ -0,0 +1,1330 @@
#!/usr/bin/env python3
"""
Comprehensive Web Interface for Fine-Tuning and Hosting Mistral Models
Provides an easy-to-use UI for training models and hosting them via API
"""

import gradio as gr
import subprocess
import os
import sys
import json
import signal
import time
import threading
import requests
import shutil
from pathlib import Path
from datetime import datetime
import torch

# Add project paths
BASE_DIR = Path(__file__).parent
MODELS_DIR = BASE_DIR / "models" / "msp"
FT_DIR = MODELS_DIR / "ft"
INFERENCE_DIR = MODELS_DIR / "inference"
API_DIR = MODELS_DIR / "api"
DATASET_DIR = BASE_DIR / "dataset"
UPLOADS_DIR = BASE_DIR / "uploads"
UPLOADS_DIR.mkdir(exist_ok=True)

sys.path.insert(0, str(MODELS_DIR))
sys.path.insert(0, str(FT_DIR))
sys.path.insert(0, str(INFERENCE_DIR))

# Global process trackers
training_process = None
api_process = None
training_log = []
api_log = []

# ==================== UTILITY FUNCTIONS ====================

def get_device_info():
    """Get information about available compute devices"""
    info = []
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            info.append(f"🎮 GPU {i}: {torch.cuda.get_device_name(i)}")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        info.append("🍎 Apple Silicon GPU (MPS) detected")
    else:
        info.append("💻 CPU only (training will be slow)")
    return "\n".join(info)

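# Example output (illustrative): "🎮 GPU 0: NVIDIA A100-SXM4-40GB" on a CUDA
# machine, "🍎 Apple Silicon GPU (MPS) detected" on a Mac, or the CPU warning
# otherwise.
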
def get_gpu_recommendations():
    """Get GPU-specific training recommendations"""
    if not torch.cuda.is_available():
        # No LoRA keys in this dict; callers fall back via .get() defaults.
        return {
            "batch_size": 1,
            "max_length": 512,
            "info": "⚠️ CPU only - Use minimal settings to avoid memory issues"
        }

    # Get GPU memory in GB
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)

    if gpu_memory_gb >= 40:  # A100 40GB or similar
        return {
            "batch_size": 4,
            "max_length": 2048,
            "lora_r": 32,
            "lora_alpha": 64,
            "info": f"🚀 High-end GPU ({gpu_memory_gb:.0f}GB) - Recommended for large batches and long sequences"
        }
    elif gpu_memory_gb >= 24:  # RTX 3090/4090 24GB
        return {
            "batch_size": 2,
            "max_length": 1536,
            "lora_r": 16,
            "lora_alpha": 32,
            "info": f"💪 High-capacity GPU ({gpu_memory_gb:.0f}GB) - Good for moderate sequences"
        }
    elif gpu_memory_gb >= 16:  # RTX 4060 Ti 16GB
        return {
            "batch_size": 2,
            "max_length": 1024,
            "lora_r": 16,
            "lora_alpha": 32,
            "info": f"✅ Mid-range GPU ({gpu_memory_gb:.0f}GB) - Suitable for standard training"
        }
    elif gpu_memory_gb >= 8:  # RTX 3060 8GB
        return {
            "batch_size": 1,
            "max_length": 768,
            "lora_r": 8,
            "lora_alpha": 16,
            "info": f"⚡ Entry-level GPU ({gpu_memory_gb:.0f}GB) - Use smaller sequences"
        }
    else:
        return {
            "batch_size": 1,
            "max_length": 512,
            "lora_r": 8,
            "lora_alpha": 16,
            "info": f"⚠️ Low VRAM GPU ({gpu_memory_gb:.0f}GB) - Use minimal settings"
        }

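# Note: CUDA devices usually report slightly less than their nominal VRAM
# (a "24 GB" card may report ~23.7 GB via total_memory), so a card marketed
# at a threshold size can land in the tier below it. The tiers above are
# heuristics, not hard requirements.
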
def list_datasets():
    """List available training datasets"""
    datasets = []
    for ext in ["*.jsonl", "*.json"]:
        datasets.extend(str(f) for f in DATASET_DIR.rglob(ext) if "claude" not in str(f))
        datasets.extend(str(f) for f in UPLOADS_DIR.rglob(ext))
    return datasets if datasets else ["No datasets found"]

def list_models():
    """List available fine-tuned models"""
    models = []

    # Check in BASE_DIR (semicon-finetuning-scripts directory)
    for item in BASE_DIR.iterdir():
        if item.is_dir() and "mistral" in item.name.lower() and not item.name.startswith('.'):
            models.append(str(item))

    # Check in BASE_DIR parent (ftt directory)
    ftt_dir = BASE_DIR.parent
    for item in ftt_dir.iterdir():
        if item.is_dir() and "mistral" in item.name.lower():
            models.append(str(item))

    # Check in MODELS_DIR
    if MODELS_DIR.exists():
        for item in MODELS_DIR.iterdir():
            if item.is_dir() and "mistral" in item.name.lower():
                models.append(str(item))

    return sorted(list(set(models))) if models else ["No models found"]

def list_base_models():
    """List available base models for fine-tuning"""
    base_models = []

    # Add the local base model
    local_base = "/workspace/ftt/base_models/Mistral-7B-v0.1"
    if Path(local_base).exists():
        base_models.append(local_base)

    # Add all fine-tuned models (can be used as base for further training);
    # skip the "No models found" placeholder
    models = list_models()
    if models != ["No models found"]:
        base_models.extend(models)

    # Add HuggingFace model IDs
    base_models.append("mistralai/Mistral-7B-v0.1")
    base_models.append("mistralai/Mistral-7B-Instruct-v0.2")

    return base_models if base_models else [local_base]

def check_api_status():
    """Check if API server is running"""
    try:
        response = requests.get("http://localhost:8000/health", timeout=2)
        if response.status_code == 200:
            data = response.json()
            return True, f"✅ API is running\n🎯 Model: {data.get('model_path', 'Unknown')}\n💻 Device: {data.get('device', 'Unknown')}"
        return False, "❌ API returned error"
    except requests.exceptions.ConnectionError:
        return False, "❌ API is not running"
    except Exception as e:
        return False, f"❌ Error: {str(e)}"

# ==================== DATASET FUNCTIONS ====================

def process_uploaded_file(file):
    """Handle uploaded dataset file"""
    if file is None:
        return None, "⚠️ No file uploaded"

    try:
        # Save uploaded file under a timestamped name to avoid collisions
        filename = Path(file.name).name
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        new_filename = f"{timestamp}_{filename}"
        save_path = UPLOADS_DIR / new_filename

        shutil.copy(file.name, save_path)

        return str(save_path), f"✅ File uploaded successfully: {new_filename}"
    except Exception as e:
        return None, f"❌ Error uploading file: {str(e)}"

def load_huggingface_dataset(dataset_name, split_ratio):
    """Load dataset from HuggingFace and split into train/val/test"""
    try:
        from datasets import load_dataset

        # Load dataset
        dataset = load_dataset(dataset_name)

        # Get the appropriate split
        if "train" in dataset:
            data = dataset["train"]
        else:
            # Use the first available split
            split_name = list(dataset.keys())[0]
            data = dataset[split_name]

        # Calculate split sizes: validation gets half of the non-training
        # percentage (hence the /200); the rest becomes test
        total_size = len(data)
        train_size = int(total_size * split_ratio / 100)
        val_size = int(total_size * (100 - split_ratio) / 200)
        test_size = total_size - train_size - val_size

        # Split dataset
        train_data = data.select(range(train_size))
        val_data = data.select(range(train_size, train_size + val_size))
        test_data = data.select(range(train_size + val_size, total_size))

        # Save splits
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_dir = UPLOADS_DIR / f"hf_{dataset_name.replace('/', '_')}_{timestamp}"
        output_dir.mkdir(exist_ok=True)

        train_path = output_dir / "train.jsonl"
        val_path = output_dir / "val.jsonl"
        test_path = output_dir / "test.jsonl"

        train_data.to_json(train_path)
        val_data.to_json(val_path)
        test_data.to_json(test_path)

        info = "✅ Dataset loaded and split successfully!\n"
        info += f"📊 Total samples: {total_size}\n"
        info += f" • Train: {train_size} samples\n"
        info += f" • Validation: {val_size} samples\n"
        info += f" • Test: {test_size} samples\n"
        info += f"📁 Saved to: {output_dir}"

        return str(train_path), info

    except Exception as e:
        return None, f"❌ Error loading HuggingFace dataset: {str(e)}"

def split_local_dataset(dataset_path, split_ratio):
    """Split local dataset into train/val/test"""
    try:
        import pandas as pd
        from sklearn.model_selection import train_test_split

        # Read dataset
        if dataset_path.endswith('.jsonl'):
            data = pd.read_json(dataset_path, lines=True)
        else:
            data = pd.read_json(dataset_path)

        total_size = len(data)

        # Calculate the training fraction
        train_ratio = split_ratio / 100

        # First split: train vs (val + test)
        train_data, temp_data = train_test_split(data, train_size=train_ratio, random_state=42)

        # Second split: val vs test (50-50 of remaining)
        val_data, test_data = train_test_split(temp_data, train_size=0.5, random_state=42)

        # Save splits
        dataset_name = Path(dataset_path).stem
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_dir = UPLOADS_DIR / f"{dataset_name}_split_{timestamp}"
        output_dir.mkdir(exist_ok=True)

        train_path = output_dir / "train.jsonl"
        val_path = output_dir / "val.jsonl"
        test_path = output_dir / "test.jsonl"

        train_data.to_json(train_path, orient='records', lines=True)
        val_data.to_json(val_path, orient='records', lines=True)
        test_data.to_json(test_path, orient='records', lines=True)

        info = "✅ Dataset split successfully!\n"
        info += f"📊 Total samples: {total_size}\n"
        info += f" • Train: {len(train_data)} samples\n"
        info += f" • Validation: {len(val_data)} samples\n"
        info += f" • Test: {len(test_data)} samples\n"
        info += f"📁 Saved to: {output_dir}"

        return str(train_path), info

    except Exception as e:
        return None, f"❌ Error splitting dataset: {str(e)}"

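# Worked example of the two-stage split above: with 1,000 rows and an 80%
# training split, the first train_test_split yields 800 train rows and 200
# held-out rows; splitting the remainder 50/50 then gives 100 validation
# and 100 test rows (80/10/10 overall).
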
# ==================== TRAINING FUNCTIONS ====================

def start_training(
    base_model,
    dataset_path,
    output_dir,
    max_length,
    num_epochs,
    batch_size,
    learning_rate,
    lora_r,
    lora_alpha
):
    """Start the fine-tuning process"""
    global training_process, training_log

    if training_process is not None and training_process.poll() is None:
        return "⚠️ Training is already running!", "Training in progress...", "".join(training_log)

    # Validate inputs
    if not dataset_path or not os.path.exists(dataset_path):
        return f"❌ Dataset not found: {dataset_path}", "Not started", ""

    if not output_dir:
        output_dir = f"./mistral-finetuned-{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Reset the log for this run before appending progress messages
    training_log = ["🚀 Starting training...\n"]

    # Clear HF cache before training to avoid stale file handles
    cache_dir = Path("/workspace/.hf_home/hub/models--mistralai--Mistral-7B-v0.1")

    # Try multiple methods to clear cache
    training_log.append("🧹 Clearing HuggingFace cache...\n")
    try:
        # Method 1: Remove all files first
        if cache_dir.exists():
            subprocess.run(["find", str(cache_dir), "-type", "f", "-delete"], check=False)
            subprocess.run(["find", str(cache_dir), "-type", "d", "-empty", "-delete"], check=False)
        # Method 2: Force remove directory
        subprocess.run(["rm", "-rf", str(cache_dir)], check=False)
        training_log.append("✓ Cache cleared successfully\n")
    except Exception as e:
        training_log.append(f"⚠️ Cache clear warning (non-critical): {e}\n")

    # Build command with unbuffered output. Only the arguments below are
    # passed as CLI flags; the remaining hyperparameters are recorded in
    # training_config.json in the output directory.
    cmd = [
        sys.executable,
        "-u",  # Unbuffered output for real-time logs
        str(FT_DIR / "finetune_mistral7b.py"),
        "--base-model", base_model,
        "--dataset", dataset_path,
        "--output-dir", output_dir,
        "--max-length", str(max_length),
    ]

    # Save configuration
    config = {
        "base_model": base_model,
        "dataset": dataset_path,
        "output_dir": output_dir,
        "max_length": max_length,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "started_at": datetime.now().isoformat()
    }

    config_path = os.path.join(output_dir, "training_config.json")
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)

    training_log.append(f"📊 Configuration saved to: {config_path}\n")
    training_log.append(f"💾 Output directory: {output_dir}\n")
    training_log.append(f"📁 Dataset: {dataset_path}\n")
    training_log.append(f"🤖 Base model: {base_model}\n")
    training_log.append(f"\n{'='*70}\n")
    training_log.append(f"Training Command:\n{' '.join(cmd)}\n")
    training_log.append(f"{'='*70}\n\n")

    # Start training process with environment for unbuffered output
    try:
        env = os.environ.copy()
        env['PYTHONUNBUFFERED'] = '1'  # Force unbuffered output

        training_process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
            bufsize=1,
            env=env
        )

        # Start log monitoring thread
        def monitor_training():
            global training_log
            for line in training_process.stdout:
                training_log.append(line)
                if len(training_log) > 1000:  # Keep last 1000 lines
                    training_log = training_log[-1000:]

        thread = threading.Thread(target=monitor_training, daemon=True)
        thread.start()

        return f"✅ Training started!\n📂 Output: {output_dir}", "Initializing training...", "".join(training_log)

    except Exception as e:
        return f"❌ Error starting training: {str(e)}", "Failed to start", "".join(training_log)

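# Illustrative only: hyperparameters that are not CLI flags above (epochs,
# batch size, learning rate, LoRA rank/alpha) live in training_config.json.
# A consumer script could recover them like this, assuming it chooses to
# read that file:
#
#     with open(os.path.join(output_dir, "training_config.json")) as f:
#         hparams = json.load(f)
#     lr = hparams["learning_rate"]
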
def stop_training():
    """Stop the training process"""
    global training_process, training_log

    if training_process is None or training_process.poll() is not None:
        return "⚠️ No training process is running", "Stopped", "".join(training_log)

    try:
        training_process.terminate()
        training_process.wait(timeout=10)
        training_log.append("\n\n🛑 Training stopped by user\n")
        return "✅ Training stopped", "Stopped by user", "".join(training_log)
    except subprocess.TimeoutExpired:
        training_process.kill()
        training_log.append("\n\n⚠️ Training force-killed\n")
        return "⚠️ Training force-killed (did not terminate gracefully)", "Force stopped", "".join(training_log)
    except Exception as e:
        return f"❌ Error stopping training: {str(e)}", "Error", "".join(training_log)

def get_training_status():
    """Get current training status"""
    global training_process, training_log

    if training_process is None:
        status = "⚪ Not started"
        progress = "Ready to start"
    elif training_process.poll() is None:
        status = "🟢 Running"
        # Try to extract progress from logs
        log_text = "".join(training_log)
        if "epoch" in log_text.lower():
            # Extract last epoch info
            lines = log_text.split('\n')
            for line in reversed(lines):
                if 'epoch' in line.lower():
                    progress = f"Training... {line.strip()}"
                    break
            else:
                progress = "Training in progress..."
        else:
            progress = "Initializing..."
    elif training_process.poll() == 0:
        status = "✅ Completed successfully"
        progress = "Training complete! Check output directory."
    else:
        status = f"❌ Failed (exit code: {training_process.poll()})"
        progress = "Training failed. Check logs for errors."

    return status, progress, "".join(training_log)

def refresh_training_log():
    """Refresh training log display"""
    global training_log
    return "".join(training_log)

# ==================== API HOSTING FUNCTIONS ====================

def start_api_server(model_path, host, port):
    """Start the API server"""
    global api_process, api_log

    if api_process is not None and api_process.poll() is None:
        return "⚠️ API server is already running!", "".join(api_log)

    # Check if it's a HuggingFace model (doesn't exist locally)
    if not os.path.exists(model_path):
        # Assume it's a HuggingFace model ID
        api_log = ["🚀 Starting API server with HuggingFace model...\n"]
        api_log.append(f"🤗 HuggingFace Model: {model_path}\n")
    else:
        api_log = ["🚀 Starting API server with local model...\n"]
        api_log.append(f"💾 Local Model: {model_path}\n")

    # Build command
    cmd = [
        sys.executable,
        str(API_DIR / "api_server.py"),
        "--model-path", model_path,
        "--host", host,
        "--port", str(port),
    ]

    api_log.append(f"🌐 Host: {host}\n")
    api_log.append(f"🔌 Port: {port}\n")
    api_log.append(f"\n{'='*70}\n")
    api_log.append(f"Server Command:\n{' '.join(cmd)}\n")
    api_log.append(f"{'='*70}\n\n")

    try:
        api_process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
            bufsize=1
        )

        # Start log monitoring thread
        def monitor_api():
            global api_log
            for line in api_process.stdout:
                api_log.append(line)
                if len(api_log) > 500:  # Keep last 500 lines
                    api_log = api_log[-500:]

        thread = threading.Thread(target=monitor_api, daemon=True)
        thread.start()

        # Wait a bit for server to start
        time.sleep(3)

        is_running, status_msg = check_api_status()
        if is_running:
            return f"✅ API server started!\n{status_msg}\n\n📡 Access at: http://{host}:{port}\n📚 Docs at: http://{host}:{port}/docs", "".join(api_log)
        else:
            return "⚠️ API server started but not responding yet. Check logs.", "".join(api_log)

    except Exception as e:
        return f"❌ Error starting API server: {str(e)}", "".join(api_log)

def stop_api_server():
    """Stop the API server"""
    global api_process, api_log

    if api_process is None or api_process.poll() is not None:
        return "⚠️ No API server is running", "".join(api_log)

    try:
        api_process.terminate()
        api_process.wait(timeout=10)
        api_log.append("\n\n🛑 API server stopped by user\n")
        return "✅ API server stopped", "".join(api_log)
    except subprocess.TimeoutExpired:
        api_process.kill()
        api_log.append("\n\n⚠️ API server force-killed\n")
        return "⚠️ API server force-killed (did not terminate gracefully)", "".join(api_log)
    except Exception as e:
        return f"❌ Error stopping API server: {str(e)}", "".join(api_log)

def get_api_status():
    """Get current API status"""
    is_running, status_msg = check_api_status()
    return status_msg, "".join(api_log)

def refresh_api_log():
    """Refresh API log display"""
    global api_log
    return "".join(api_log)

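# Note: start_api_server probes /health after a fixed 3-second wait. A 7B
# model can take considerably longer to load, so "not responding yet" is
# expected on slower machines; use the Refresh Status button until the
# health check passes.
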
# ==================== INFERENCE FUNCTIONS ====================

def test_inference(model_path, prompt, max_length, temperature):
    """Test inference with the model"""
    try:
        # Check if API is running first
        is_running, _ = check_api_status()

        if is_running:
            # Use API
            response = requests.post(
                "http://localhost:8000/api/generate",
                json={
                    "prompt": prompt,
                    "max_length": int(max_length),
                    "temperature": float(temperature)
                },
                timeout=120
            )
            response.raise_for_status()
            result = response.json()
            return f"✅ Response via API:\n\n{result['response']}"
        else:
            # Use direct inference; the loader handles both local paths and
            # HuggingFace model IDs
            from inference.inference_mistral7b import load_local_model, generate_with_local_model

            model, tokenizer = load_local_model(model_path)

            response = generate_with_local_model(
                model, tokenizer, prompt,
                max_length=int(max_length),
                temperature=float(temperature)
            )
            return f"✅ Response via Direct Inference:\n\n{response}"

    except Exception as e:
        return f"❌ Error during inference: {str(e)}"

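# Illustrative standalone call (not wired into the UI); the checkpoint path
# is hypothetical:
#
#     print(test_inference("./mistral-finetuned-example",
#                          "Write a 2-to-1 mux in Verilog.", 512, 0.7))
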
# ==================== UI CREATION ====================

def create_interface():
    """Create the Gradio interface"""

    # Get GPU recommendations
    gpu_rec = get_gpu_recommendations()

    with gr.Blocks(title="Mistral Fine-Tuning & Hosting Interface") as app:
        gr.Markdown("# 🚀 Mistral Model Fine-Tuning & Hosting Interface")
        gr.Markdown("Complete interface for training and deploying Mistral models")

        # Device info and controls
        with gr.Row():
            with gr.Column(scale=3):
                device_info = get_device_info()
                gr.Markdown(f"### 💻 System Information\n{device_info}\n\n{gpu_rec['info']}")

            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ System Controls")

                def kill_gradio_server():
                    """Kill the Gradio server process"""
                    pid = os.getpid()

                    # Schedule the kill to happen after this function returns
                    def delayed_kill():
                        time.sleep(1)
                        os.kill(pid, signal.SIGTERM)

                    threading.Thread(target=delayed_kill, daemon=True).start()
                    return "🛑 Shutting down Gradio server in 1 second...", api_server_status.value

                def stop_api_control():
                    """Stop API server from control panel"""
                    status, _ = stop_api_server()
                    return server_status.value, status

                server_status = gr.Textbox(
                    label="Gradio Server Status",
                    value="🟢 Running",
                    interactive=False,
                    lines=1
                )

                api_server_status = gr.Textbox(
                    label="API Server Status",
                    value="⚪ Not started",
                    interactive=False,
                    lines=1
                )

                with gr.Row():
                    kill_server_btn = gr.Button("🛑 Shutdown Gradio", variant="stop", scale=1)
                    stop_api_btn_control = gr.Button("⏹️ Stop API Server", variant="secondary", scale=1)

                kill_server_btn.click(
                    fn=kill_gradio_server,
                    outputs=[server_status, api_server_status]
                )

                stop_api_btn_control.click(
                    fn=stop_api_control,
                    outputs=[server_status, api_server_status]
                )

        # Main tabs
        with gr.Tabs() as tabs:

            # ========== FINE-TUNING TAB ==========
            with gr.Tab("🎓 Fine-Tuning"):
                gr.Markdown("### Configure and start model fine-tuning")

                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("#### Training Configuration")

                        base_model_input = gr.Dropdown(
                            label="Base Model (Select existing model or HuggingFace ID)",
                            choices=list_base_models(),
                            value=list_base_models()[0] if list_base_models() else "/workspace/ftt/base_models/Mistral-7B-v0.1",
                            allow_custom_value=True,
                            info="💡 Select a base model to start from, or a fine-tuned model to continue training"
                        )

                        gr.Markdown("#### Dataset Selection")

                        dataset_source = gr.Radio(
                            choices=["Local File", "Upload File", "HuggingFace Dataset"],
                            value="Local File",
                            label="Dataset Source"
                        )

                        # Local file selection
                        dataset_input = gr.Dropdown(
                            label="Select Local Dataset",
                            choices=list_datasets(),
                            value=list_datasets()[0] if list_datasets()[0] != "No datasets found" else None,
                            allow_custom_value=True,
                            visible=True
                        )

                        # File upload
                        dataset_upload = gr.File(
                            label="Upload Dataset File (JSON/JSONL)",
                            file_types=[".json", ".jsonl"],
                            visible=False
                        )

                        # HuggingFace dataset
                        hf_dataset_input = gr.Textbox(
                            label="HuggingFace Dataset Name",
                            placeholder="e.g., timdettmers/openassistant-guanaco",
                            visible=False
                        )

                        # Dataset splitting
                        gr.Markdown("#### Dataset Processing")
                        split_dataset = gr.Checkbox(
                            label="Split dataset into train/val/test",
                            value=False
                        )

                        split_ratio = gr.Slider(
                            label="Training Split % (remaining split equally between val/test)",
                            minimum=60,
                            maximum=90,
                            value=80,
                            step=5
                        )

                        process_dataset_btn = gr.Button("📊 Process Dataset")
                        dataset_status = gr.Textbox(label="Dataset Status", interactive=False, lines=6)

                        output_dir_input = gr.Textbox(
                            label="Output Directory",
                            value=f"./mistral-finetuned-{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                            placeholder="Where to save the fine-tuned model"
                        )

                        gr.Markdown("#### Training Parameters")
                        gr.Markdown(f"*💡 GPU-Optimized Defaults: Batch={gpu_rec['batch_size']}, Max Length={gpu_rec['max_length']}, LoRA Rank={gpu_rec.get('lora_r', 16)}*")

                        gr.Markdown("---")
                        gr.Markdown("**Sequence & Training Settings**")

                        with gr.Row():
                            max_length_input = gr.Slider(
                                label="Max Sequence Length",
                                info="📏 Tokens per example | Higher=more context but more memory | Standard: 512-2048 | Your GPU: " + str(gpu_rec['max_length']),
                                minimum=128,
                                maximum=6000,
                                value=gpu_rec['max_length'],
                                step=128
                            )

                        with gr.Row():
                            num_epochs_input = gr.Slider(
                                label="Number of Epochs",
                                info="🔁 Training passes | More=better learning but risk of overfitting | Standard: 3-5 | Quick test: 1",
                                minimum=1,
                                maximum=10,
                                value=3,
                                step=1
                            )

                            batch_size_input = gr.Slider(
                                label="Batch Size",
                                info="📦 Samples together | Larger=faster but more memory | Your GPU: " + str(gpu_rec['batch_size']) + " | Low VRAM: 1",
                                minimum=1,
                                maximum=16,
                                value=gpu_rec['batch_size'],
                                step=1
                            )

                        with gr.Row():
                            learning_rate_input = gr.Number(
                                label="Learning Rate",
                                info="⚡ Training speed | Typical: 1e-5 to 5e-4 | Lower=stable | Higher=fast | Default: 5e-5",
                                value=5e-5,
                                precision=6
                            )

                        gr.Markdown("---")
                        gr.Markdown("**LoRA Configuration** *(Efficient fine-tuning by training a small parameter subset)*")

                        with gr.Row():
                            lora_r_input = gr.Slider(
                                label="LoRA Rank (r)",
                                info="🎯 Adaptation matrix rank | Higher=more capacity/slower | Standard: 8-32 | Your GPU: " + str(gpu_rec.get('lora_r', 16)),
                                minimum=4,
                                maximum=64,
                                value=gpu_rec.get('lora_r', 16),
                                step=4
                            )

                            lora_alpha_input = gr.Slider(
                                label="LoRA Alpha",
                                info="⚖️ Scaling factor | Typically 2× rank | Controls adaptation strength | Recommended: " + str(gpu_rec.get('lora_alpha', 32)),
                                minimum=8,
                                maximum=128,
                                value=gpu_rec.get('lora_alpha', 32),
                                step=8
                            )

                        with gr.Row():
                            start_train_btn = gr.Button("▶️ Start Training", variant="primary")
                            stop_train_btn = gr.Button("⏹️ Stop Training", variant="stop")
                            refresh_train_btn = gr.Button("🔄 Refresh Status")

                    with gr.Column(scale=2):
                        gr.Markdown("#### Training Status & Logs")

                        training_status_output = gr.Textbox(
                            label="Status (Right-click to copy)",
                            value="⚪ Not started",
                            interactive=False,
                            lines=2,
                            max_lines=3
                        )

                        training_progress = gr.Textbox(
                            label="Progress - Epoch/Loss Info (Right-click to copy)",
                            value="Ready to start",
                            interactive=False,
                            lines=2,
                            max_lines=3
                        )

                        training_log_output = gr.Textbox(
                            label="Training Logs - Scrollable (Click Refresh to update, Right-click to copy)",
                            lines=22,
                            max_lines=22,
                            interactive=False
                        )

                # Dataset source switching
                def update_dataset_visibility(source):
                    return (
                        gr.update(visible=(source == "Local File")),
                        gr.update(visible=(source == "Upload File")),
                        gr.update(visible=(source == "HuggingFace Dataset"))
                    )

                dataset_source.change(
                    fn=update_dataset_visibility,
                    inputs=[dataset_source],
                    outputs=[dataset_input, dataset_upload, hf_dataset_input]
                )

                # Dataset processing
                def process_dataset(source, local_path, uploaded_file, hf_name, should_split, ratio):
                    if source == "Upload File":
                        if uploaded_file is None:
                            return None, "⚠️ Please upload a file"
                        path, msg = process_uploaded_file(uploaded_file)
                        if path and should_split:
                            path, msg = split_local_dataset(path, ratio)
                        return path, msg
                    elif source == "HuggingFace Dataset":
                        if not hf_name:
                            return None, "⚠️ Please enter a HuggingFace dataset name"
                        return load_huggingface_dataset(hf_name, ratio)
                    else:  # Local File
                        if not local_path or local_path == "No datasets found":
                            return None, "⚠️ Please select a dataset"
                        if should_split:
                            return split_local_dataset(local_path, ratio)
                        return local_path, f"✅ Using existing dataset: {local_path}"

                process_dataset_btn.click(
                    fn=process_dataset,
                    inputs=[dataset_source, dataset_input, dataset_upload, hf_dataset_input, split_dataset, split_ratio],
                    outputs=[dataset_input, dataset_status]
                )

                # Connect training buttons
                start_train_btn.click(
                    fn=start_training,
                    inputs=[
                        base_model_input, dataset_input, output_dir_input,
                        max_length_input, num_epochs_input, batch_size_input,
                        learning_rate_input, lora_r_input, lora_alpha_input
                    ],
                    outputs=[training_status_output, training_progress, training_log_output]
                )

                stop_train_btn.click(
                    fn=stop_training,
                    outputs=[training_status_output, training_progress, training_log_output]
                )

                refresh_train_btn.click(
                    fn=get_training_status,
                    outputs=[training_status_output, training_progress, training_log_output]
                )

            # ========== API HOSTING TAB ==========
            with gr.Tab("🌐 API Hosting"):
                gr.Markdown("### Start and manage API server for model inference")

                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("#### Server Configuration")

                        api_model_source = gr.Radio(
                            choices=["Local Model", "HuggingFace Model"],
                            value="Local Model",
                            label="Model Source"
                        )

                        api_model_input = gr.Dropdown(
                            label="Select Local Model",
                            choices=list_models(),
                            value=list_models()[0] if list_models()[0] != "No models found" else None,
                            allow_custom_value=True
                        )

                        api_hf_model_input = gr.Textbox(
                            label="HuggingFace Model ID",
                            placeholder="e.g., mistralai/Mistral-7B-v0.1 or your-username/your-model",
                            visible=False
                        )

                        api_host_input = gr.Textbox(
                            label="Host",
                            value="0.0.0.0",
                            placeholder="0.0.0.0 for all interfaces"
                        )

                        api_port_input = gr.Number(
                            label="Port",
                            value=8000,
                            precision=0
                        )

                        with gr.Row():
                            start_api_btn = gr.Button("▶️ Start Server", variant="primary")
                            stop_api_btn = gr.Button("⏹️ Stop Server", variant="stop")
                            refresh_api_btn = gr.Button("🔄 Refresh Status")

                        api_status_output = gr.Textbox(
                            label="Server Status",
                            value="⚪ Not started",
                            interactive=False,
                            lines=5
                        )

                    with gr.Column(scale=2):
                        gr.Markdown("#### Server Logs")

                        api_log_output = gr.Textbox(
                            label="API Server Logs",
                            lines=35,
                            max_lines=35,
                            interactive=False
                        )

                # Model source switching
                def update_api_model_visibility(source):
                    return (
                        gr.update(visible=(source == "Local Model")),
                        gr.update(visible=(source == "HuggingFace Model"))
                    )

                api_model_source.change(
                    fn=update_api_model_visibility,
                    inputs=[api_model_source],
                    outputs=[api_model_input, api_hf_model_input]
                )

                # API server buttons
                def start_api_wrapper(source, local_model, hf_model, host, port):
                    model_path = hf_model if source == "HuggingFace Model" else local_model
                    if not model_path:
                        return "⚠️ Please select or enter a model", ""
                    return start_api_server(model_path, host, port)

                start_api_btn.click(
                    fn=start_api_wrapper,
                    inputs=[api_model_source, api_model_input, api_hf_model_input, api_host_input, api_port_input],
                    outputs=[api_status_output, api_log_output]
                )

                stop_api_btn.click(
                    fn=stop_api_server,
                    outputs=[api_status_output, api_log_output]
                )

                refresh_api_btn.click(
                    fn=get_api_status,
                    outputs=[api_status_output, api_log_output]
                )

            # ========== INFERENCE TAB ==========
            with gr.Tab("🧪 Test Inference"):
                gr.Markdown("### Test your fine-tuned models")
                gr.Markdown("💡 The interface will use the API if it's running; otherwise it will load the model directly")

                with gr.Row():
                    with gr.Column(scale=1):
                        inference_model_source = gr.Radio(
                            choices=["Local Model", "HuggingFace Model"],
                            value="Local Model",
                            label="Model Source"
                        )

                        inference_model_input = gr.Dropdown(
                            label="Select Local Model",
                            choices=list_models(),
                            value=list_models()[0] if list_models()[0] != "No models found" else None,
                            allow_custom_value=True
                        )

                        inference_hf_model_input = gr.Textbox(
                            label="HuggingFace Model ID",
                            placeholder="e.g., mistralai/Mistral-7B-v0.1",
                            visible=False
                        )

                        gr.Markdown("#### Prompt Configuration")

                        inference_system_instruction = gr.Textbox(
                            label="System Instruction (Pre-filled, editable)",
                            lines=4,
                            value="You are Elinnos RTL Code Generator v1.0, a specialized Verilog/SystemVerilog code generation agent. Your role: Generate clean, synthesizable RTL code for hardware design tasks. Output ONLY functional RTL code with no $display, assertions, comments, or debug statements.",
                            info="💡 This is pre-filled with your model's training format. Edit if needed."
                        )

                        inference_user_prompt = gr.Textbox(
                            label="User Prompt (Your request)",
                            lines=3,
                            placeholder="Example: Generate a synchronous FIFO with 8-bit data width, depth 4, write_enable, read_enable, full flag, empty flag.",
                            value=""
                        )

                        gr.Markdown("#### Generation Parameters")

                        with gr.Row():
                            inference_max_length = gr.Slider(
                                label="Max Length",
                                info="Maximum tokens to generate. Higher = longer responses but slower",
                                minimum=128,
                                maximum=6000,
                                value=512,
                                step=128
                            )

                            inference_temperature = gr.Slider(
                                label="Temperature",
                                info="Creativity control: 0.1=focused/deterministic, 1.0=creative/random",
                                minimum=0.1,
                                maximum=2.0,
                                value=0.7,
                                step=0.1
                            )

                        inference_btn = gr.Button("🚀 Generate", variant="primary")

                    with gr.Column(scale=2):
                        inference_output = gr.Textbox(
                            label="Generated Response",
                            lines=30,
                            interactive=False
                        )

                # Model source switching
                def update_inference_model_visibility(source):
                    return (
                        gr.update(visible=(source == "Local Model")),
                        gr.update(visible=(source == "HuggingFace Model"))
                    )

                inference_model_source.change(
                    fn=update_inference_model_visibility,
                    inputs=[inference_model_source],
                    outputs=[inference_model_input, inference_hf_model_input]
                )

                # Inference
                def test_inference_wrapper(source, local_model, hf_model, system_instruction, user_prompt, max_len, temp):
                    model_path = hf_model if source == "HuggingFace Model" else local_model
                    if not model_path:
                        return "⚠️ Please select or enter a model"

                    # Combine system instruction and user prompt
                    full_prompt = f"{system_instruction}\n\nUser:\n{user_prompt}"

                    return test_inference(model_path, full_prompt, max_len, temp)

                inference_btn.click(
                    fn=test_inference_wrapper,
                    inputs=[inference_model_source, inference_model_input, inference_hf_model_input,
                            inference_system_instruction, inference_user_prompt, inference_max_length, inference_temperature],
                    outputs=inference_output
                )

            # ========== DOCUMENTATION TAB ==========
            with gr.Tab("📚 Documentation"):
                gr.Markdown("""
## 📖 User Guide

### 🎓 Fine-Tuning

#### Dataset Options

**1. Local File**: Select from existing datasets in the workspace
- Use datasets already present in the `dataset/` directory

**2. Upload File**: Upload your own dataset file
- Supported formats: JSON, JSONL
- Files are saved to the `uploads/` directory

**3. HuggingFace Dataset**: Load from HuggingFace Hub
- Enter dataset name (e.g., `timdettmers/openassistant-guanaco`)
- Automatically downloaded and processed

#### Dataset Processing

- **Split Dataset**: Automatically split into train/validation/test sets
- **Split Ratio**: Control train percentage (default 80%)
- Remaining data is split equally between validation and test

#### Training Parameters Explained

**Max Sequence Length**
- Number of tokens (words/subwords) per training example
- Higher = more context but requires more GPU memory
- Standard: 512-2048, Maximum: 6000 (for long documents)
- **Recommendation**: Start with the GPU-recommended value

**Number of Epochs**
- How many complete passes through your dataset
- More epochs = better learning but risk of overfitting
- Standard: 3-5 epochs
- Watch training loss to avoid overfitting

**Batch Size**
- Number of examples processed simultaneously
- Larger = faster training but more memory
- Limited by your GPU memory
- **GPU-based recommendations provided automatically**

**Learning Rate**
- Controls how quickly the model adapts
- Too high = unstable training, too low = slow convergence
- Standard: 1e-5 to 5e-4
- The default 5e-5 works well for most cases

**LoRA Rank (r)**
- Rank of low-rank adaptation matrices
- Higher = more model capacity but slower training
- Standard: 8-32
- Use lower values for smaller datasets

**LoRA Alpha**
- Scaling factor for LoRA updates
- Typically set to 2× the rank
- Controls strength of fine-tuning adaptations

### 🌐 API Hosting

#### Model Sources

**Local Model**: Models saved on your machine
- Fine-tuned models from training
- Downloaded HuggingFace models

**HuggingFace Model**: Direct from HuggingFace Hub
- Enter model ID (e.g., `mistralai/Mistral-7B-v0.1`)
- No need to download first
- Automatically cached after first use

#### API Endpoints

Once running, access these endpoints:
- **Generate**: `POST http://localhost:8000/api/generate`
- **Health**: `GET http://localhost:8000/health`
- **Docs**: `http://localhost:8000/docs` (Interactive API docs)

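A minimal Python client sketch (illustrative; assumes the server is running on `localhost:8000` and uses the same request fields as the Test Inference tab):

```python
import requests

resp = requests.post(
    "http://localhost:8000/api/generate",
    json={"prompt": "Hello!", "max_length": 256, "temperature": 0.7},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])  # generated text is returned under "response"
```
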
### 🧪 Testing Inference

#### Model Selection

- **Local Model**: Use models from your filesystem
- **HuggingFace Model**: Test any model from HuggingFace Hub

#### Generation Parameters

**Max Length**
- Maximum number of tokens to generate
- Higher = longer responses but slower generation
- Balance between quality and speed
- Typical: 256-1024 for most tasks

**Temperature**
- Controls randomness in generation
- **0.1-0.3**: Very focused, deterministic (good for factual tasks)
- **0.5-0.7**: Balanced creativity (default, recommended)
- **0.8-1.0**: Creative, diverse outputs
- **1.0+**: Very random (experimental, often incoherent)

### 💡 Tips & Best Practices

#### GPU Memory Management
- **Out of Memory?** Reduce batch size or max sequence length
- Monitor GPU usage with `nvidia-smi`
- Use gradient checkpointing for very long sequences

#### Training Tips
- **Start Small**: Test with a small subset first
- **Monitor Loss**: Should decrease steadily
- **Early Stopping**: Stop if validation loss increases
- **Save Checkpoints**: Training saves to the output directory

#### Dataset Quality
- **Format Consistency**: Ensure all examples follow the same format
- **Quality over Quantity**: 1,000 good examples > 10,000 poor ones
- **Diverse Examples**: Cover different aspects of your task

#### Model Selection
- **Base Model**: Start with Mistral-7B-v0.1 (good balance)
- **Fine-tuned Models**: Use domain-specific models if available
- **Test First**: Always test inference before production

### 🔧 Dataset Format

Your training data should be in JSONL format (one JSON object per line):

**Format 1: Instruction-Response**
```json
{"instruction": "Your question or task", "response": "Expected answer"}
```

**Format 2: Prompt-Completion**
```json
{"prompt": "Your question", "completion": "Expected answer"}
```

**Format 3: Chat Format**
```json
{"messages": [
    {"role": "user", "content": "Question"},
    {"role": "assistant", "content": "Answer"}
]}
```

### 🚨 Troubleshooting

**Training Issues**
- **Out of Memory**: Reduce batch size, max sequence length, or LoRA rank
- **Slow Training**: Check GPU utilization; ensure CUDA is available
- **NaN Loss**: Reduce learning rate or check data quality
- **No Improvement**: Increase epochs, learning rate, or dataset size

**API Issues**
- **Server Won't Start**: Check whether the port is already in use
- **Connection Refused**: Ensure the firewall allows the port
- **Slow Inference**: Model loading can take time on the first request
- **Out of Memory**: Model too large for the GPU; use a smaller model or CPU

**Model Issues**
- **Model Not Found**: Verify the path or HuggingFace model ID
- **Poor Quality**: May need more training data or epochs
- **Inconsistent Output**: Lower the temperature

### 📊 Performance Benchmarks

**GPU Memory Requirements (Mistral-7B)**
- Training (LoRA): ~12-16GB VRAM
- Inference: ~8-10GB VRAM
- Batch size 1: minimum required
- Batch size 4: optimal on a 40GB GPU

**Training Speed (A100 40GB)**
- ~5000 tokens/second
- 10k examples: ~30-60 minutes
- Depends on sequence length and batch size

### 🔄 Recent Updates

**v2.0 Features**
- ✅ File upload for datasets
- ✅ HuggingFace dataset integration
- ✅ Automatic dataset splitting (train/val/test)
- ✅ Extended max sequence length to 6000 tokens
- ✅ GPU-specific parameter recommendations
- ✅ HuggingFace model support for API hosting
- ✅ HuggingFace model support for inference
- ✅ Enhanced parameter tooltips and descriptions
- ✅ Public URL sharing enabled
- ✅ Improved documentation

### 📞 Support

For issues or questions:
- Check logs for error messages
- Verify GPU availability and memory
- Ensure all dependencies are installed
- Review dataset format and quality
""")

        # Note: Auto-refresh can be enabled with gr.Timer in newer Gradio versions.
        # For now, use the manual refresh buttons to update logs.

    return app

# ==================== MAIN ====================

def main():
    """Launch the application"""
    print("=" * 70)
    print("🚀 Mistral Fine-Tuning & Hosting Interface v2.0")
    print("=" * 70)
    print("\n💻 System Information:")
    print(get_device_info())
    gpu_rec = get_gpu_recommendations()
    print(f"\n{gpu_rec['info']}")
    print(f"\n📁 Base Directory: {BASE_DIR}")
    print(f"📊 Available Datasets: {len(list_datasets())}")
    print(f"🤖 Available Models: {len(list_models())}")
    print("\n" + "=" * 70)
    print("🌐 Starting web interface...")
    print("=" * 70 + "\n")

    app = create_interface()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )

if __name__ == "__main__":
    main()