from flask import Flask, jsonify, request, send_file
import threading
import time
import os
import tempfile
import shutil
import uuid
import zipfile
import io
from datetime import datetime, timedelta


app = Flask(__name__)

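# All jobs are tracked in this in-process dict, keyed by job ID; note that
# running the app with multiple worker processes would give each worker its
# own, separate registry.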
training_jobs = {}


class TrainingProgress:
    def __init__(self, job_id):
        self.job_id = job_id
        self.status = "initializing"
        self.progress = 0
        self.current_step = 0
        self.total_steps = 0
        self.start_time = time.time()
        self.estimated_finish_time = None
        self.message = "Starting training..."
        self.error = None
        self.model_path = None
        self.detected_columns = None

    def update_progress(self, current_step, total_steps, message=""):
        self.current_step = current_step
        self.total_steps = total_steps
        self.progress = (current_step / total_steps) * 100 if total_steps > 0 else 0
        self.message = message

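        # ETA by linear extrapolation: average seconds per completed step,
        # multiplied by the number of steps remaining.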
        if current_step > 0:
            elapsed_time = time.time() - self.start_time
            time_per_step = elapsed_time / current_step
            remaining_steps = total_steps - current_step
            estimated_remaining_time = remaining_steps * time_per_step
            self.estimated_finish_time = datetime.now() + timedelta(seconds=estimated_remaining_time)

    def to_dict(self):
        return {
            "job_id": self.job_id,
            "status": self.status,
            "progress": round(self.progress, 2),
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "message": self.message,
            "estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
            "error": self.error,
            "model_path": self.model_path,
            "detected_columns": self.detected_columns
        }


def detect_qa_columns(dataset):
    """Automatically detect question and answer columns in the dataset"""
    question_patterns = [
        'question', 'prompt', 'input', 'query', 'patient', 'user', 'human',
        'instruction', 'context', 'q', 'text', 'source'
    ]

    answer_patterns = [
        'answer', 'response', 'output', 'reply', 'doctor', 'assistant', 'ai',
        'completion', 'target', 'a', 'label', 'ground_truth'
    ]

    columns = list(dataset.column_names)

    question_col = None
    for pattern in question_patterns:
        for col in columns:
            if pattern.lower() in col.lower():
                question_col = col
                break
        if question_col:
            break

    answer_col = None
    for pattern in answer_patterns:
        for col in columns:
            if pattern.lower() in col.lower() and col != question_col:
                answer_col = col
                break
        if answer_col:
            break

    if not question_col or not answer_col:
        text_columns = []
        for col in columns:
            sample = dataset[0][col]
            if isinstance(sample, str) and len(sample.strip()) > 0:
                text_columns.append(col)

        if len(text_columns) >= 2:
            question_col = text_columns[0]
            answer_col = text_columns[1]
        elif len(text_columns) == 1:
            question_col = answer_col = text_columns[0]

    return question_col, answer_col
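
# Example: with the default ruslanmv/ai-medical-chatbot dataset (assuming its
# columns include "Patient" and "Doctor"), this returns ("Patient", "Doctor"),
# because "patient" and "doctor" appear in the pattern lists above.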


def train_model_background(job_id, dataset_name, base_model_name=None):
    """Background training function with progress tracking"""
    progress = training_jobs[job_id]

    try:
        temp_dir = tempfile.mkdtemp(prefix=f"train_{job_id}_")

        os.environ['HF_HOME'] = temp_dir
        os.environ['TRANSFORMERS_CACHE'] = temp_dir
        os.environ['HF_DATASETS_CACHE'] = temp_dir
        os.environ['TORCH_HOME'] = temp_dir
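        # These cache locations are read when transformers/datasets are first
        # imported, which is one reason the heavyweight imports happen below,
        # inside this function. Note that os.environ is process-global, so
        # concurrent jobs will overwrite each other's cache settings.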

        progress.status = "loading_libraries"
        progress.message = "Loading required libraries..."

        import torch
        from datasets import load_dataset, Dataset
        from huggingface_hub import login
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            TrainingArguments,
            Trainer,
            TrainerCallback,
        )
        from peft import (
            LoraConfig,
            get_peft_model,
        )

        hf_token = os.getenv('HF_TOKEN')
        if hf_token:
            login(token=hf_token)

        progress.status = "loading_model"
        progress.message = "Loading base model and tokenizer..."

        base_model = base_model_name or "microsoft/DialoGPT-small"
        new_model = f"trained-model-{job_id}"
        max_length = 256

        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            cache_dir=temp_dir,
            torch_dtype=torch.float32,
            device_map="auto" if torch.cuda.is_available() else "cpu",
            trust_remote_code=True
        )

        tokenizer = AutoTokenizer.from_pretrained(
            base_model,
            cache_dir=temp_dir,
            trust_remote_code=True
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model.resize_token_embeddings(len(tokenizer))

        progress.status = "preparing_model"
        progress.message = "Setting up LoRA configuration..."

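        # LoRA injects small trainable low-rank matrices (rank r, scaled by
        # lora_alpha / r) into the model, typically at the attention
        # projections, while the base weights stay frozen, so only a small
        # fraction of parameters is updated.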
        peft_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, peft_config)

        progress.status = "loading_dataset"
        progress.message = "Loading and preparing dataset..."

        # Load the dataset once, then use its "train" split if present,
        # otherwise fall back to the first available split.
        dataset_dict = load_dataset(
            dataset_name,
            cache_dir=temp_dir,
            trust_remote_code=True
        )
        if "train" in dataset_dict:
            dataset = dataset_dict["train"]
        else:
            dataset = dataset_dict[next(iter(dataset_dict))]

        question_col, answer_col = detect_qa_columns(dataset)

        if not question_col or not answer_col:
            raise ValueError("Could not automatically detect question and answer columns in the dataset")

        progress.detected_columns = {"question": question_col, "answer": answer_col}
        progress.message = f"Detected columns - Question: {question_col}, Answer: {answer_col}"

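        # Cap the run at 1,000 shuffled examples to keep training time short.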
        dataset = dataset.shuffle(seed=65).select(range(min(1000, len(dataset))))

        class CustomDataset(torch.utils.data.Dataset):
            def __init__(self, texts, tokenizer, max_length):
                self.texts = texts
                self.tokenizer = tokenizer
                self.max_length = max_length

            def __len__(self):
                return len(self.texts)

            def __getitem__(self, idx):
                text = self.texts[idx]

                encoding = self.tokenizer(
                    text,
                    truncation=True,
                    padding='max_length',
                    max_length=self.max_length,
                    return_tensors='pt'
                )

                input_ids = encoding['input_ids'].squeeze()
                attention_mask = encoding['attention_mask'].squeeze()

                labels = input_ids.clone()

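                # Positions where attention_mask == 0 are padding; setting
                # their label to -100 makes the cross-entropy loss ignore them
                # (PyTorch's default ignore_index).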
                labels[attention_mask == 0] = -100

                return {
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                    'labels': labels
                }
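        # Flatten each record into a single training string of the form
        # "Question: <question>\nAnswer: <answer><eos>", so the model learns
        # to continue a question with its answer.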
        texts = []
        for item in dataset:
            question = str(item[question_col]).strip()
            answer = str(item[answer_col]).strip()
            text = f"Question: {question}\nAnswer: {answer}{tokenizer.eos_token}"
            texts.append(text)

        train_dataset = CustomDataset(texts, tokenizer, max_length)

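        # Step-count estimate for the progress bar: with at most 1,000
        # examples, batch_size=2, and no gradient accumulation, this works
        # out to at most 500 optimizer steps over the single epoch.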
        batch_size = 2
        gradient_accumulation_steps = 1
        num_epochs = 1

        steps_per_epoch = len(train_dataset) // (batch_size * gradient_accumulation_steps)
        total_steps = steps_per_epoch * num_epochs

        progress.total_steps = total_steps
        progress.status = "training"
        progress.message = "Starting training..."

        output_dir = os.path.join(temp_dir, new_model)
        os.makedirs(output_dir, exist_ok=True)

        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            num_train_epochs=num_epochs,
            logging_steps=1,
            save_steps=max(1, total_steps // 2),
            save_total_limit=1,
            learning_rate=5e-5,
            warmup_steps=2,
            logging_strategy="steps",
            save_strategy="steps",
            fp16=False,
            bf16=False,
            dataloader_num_workers=0,
            remove_unused_columns=False,
            report_to="none",  # the string "none" disables W&B/TensorBoard; None enables all installed integrations
            prediction_loss_only=True,
        )

        class ProgressCallback(TrainerCallback):
            def __init__(self, progress_tracker):
                self.progress_tracker = progress_tracker
                self.last_update = time.time()

            def on_log(self, args, state, control, model=None, logs=None, **kwargs):
                current_time = time.time()

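                # Throttle status updates to at most one every three seconds.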
                if current_time - self.last_update >= 3:
                    self.progress_tracker.update_progress(
                        state.global_step,
                        state.max_steps,
                        f"Training step {state.global_step}/{state.max_steps}"
                    )
                    self.last_update = current_time

                if logs:
                    loss = logs.get('train_loss', logs.get('loss', 'N/A'))
                    self.progress_tracker.message = f"Step {state.global_step}/{state.max_steps}, Loss: {loss}"

            def on_train_begin(self, args, state, control, **kwargs):
                self.progress_tracker.status = "training"
                self.progress_tracker.message = "Training started..."

            def on_train_end(self, args, state, control, **kwargs):
                self.progress_tracker.status = "saving"
                self.progress_tracker.message = "Training complete, saving model..."

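        # Note: transformers >= 4.46 deprecates Trainer's `tokenizer` argument
        # in favor of `processing_class`; passing `tokenizer=` still works on
        # recent versions but emits a deprecation warning.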
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            callbacks=[ProgressCallback(progress)],
            tokenizer=tokenizer,
        )

        trainer.train()
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        progress.model_path = output_dir
        progress.status = "completed"
        progress.progress = 100
        progress.message = "Training completed! Model ready for download."

        # Keep the finished model on disk for one hour so it can be
        # downloaded, then remove it and drop the job from the registry.
        def cleanup_temp_dir():
            time.sleep(3600)
            try:
                shutil.rmtree(temp_dir)
                if job_id in training_jobs:
                    del training_jobs[job_id]
            except Exception:
                pass

        cleanup_thread = threading.Thread(target=cleanup_temp_dir)
        cleanup_thread.daemon = True
        cleanup_thread.start()

    except Exception as e:
        progress.status = "error"
        progress.error = str(e)
        progress.message = f"Training failed: {str(e)}"

        # Best-effort cleanup of any partially downloaded files.
        try:
            if 'temp_dir' in locals():
                shutil.rmtree(temp_dir)
        except Exception:
            pass


def create_model_zip(model_path, job_id):
    """Create a zip file containing the trained model"""
    memory_file = io.BytesIO()

    with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
        for root, dirs, files in os.walk(model_path):
            for file in files:
                file_path = os.path.join(root, file)
                arc_name = os.path.relpath(file_path, model_path)
                zf.write(file_path, arc_name)

    memory_file.seek(0)
    return memory_file


@app.route('/api/train', methods=['POST'])
def start_training():
    """Start training and return job ID for tracking"""
    try:
        data = request.get_json() if request.is_json else {}
        dataset_name = data.get('dataset_name', 'ruslanmv/ai-medical-chatbot')
        base_model_name = data.get('base_model', 'microsoft/DialoGPT-small')

        job_id = str(uuid.uuid4())[:8]
        progress = TrainingProgress(job_id)
        training_jobs[job_id] = progress

        training_thread = threading.Thread(
            target=train_model_background,
            args=(job_id, dataset_name, base_model_name)
        )
        training_thread.daemon = True
        training_thread.start()

        return jsonify({
            "status": "started",
            "job_id": job_id,
            "dataset_name": dataset_name,
            "base_model": base_model_name,
            "message": "Training started. Use /api/status/<job_id> to track progress."
        })

    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500


@app.route('/api/status/<job_id>', methods=['GET'])
def get_training_status(job_id):
    """Get training progress and estimated completion time"""
    if job_id not in training_jobs:
        return jsonify({"status": "error", "message": "Job not found"}), 404

    progress = training_jobs[job_id]
    return jsonify(progress.to_dict())


@app.route('/api/download/<job_id>', methods=['GET'])
def download_model(job_id):
    """Download the trained model as a zip file"""
    if job_id not in training_jobs:
        return jsonify({"status": "error", "message": "Job not found"}), 404

    progress = training_jobs[job_id]

    if progress.status != "completed":
        return jsonify({
            "status": "error",
            "message": f"Model not ready for download. Current status: {progress.status}"
        }), 400

    if not progress.model_path or not os.path.exists(progress.model_path):
        return jsonify({
            "status": "error",
            "message": "Model files not found. They may have been cleaned up."
        }), 404

    try:
        zip_file = create_model_zip(progress.model_path, job_id)

        return send_file(
            zip_file,
            as_attachment=True,
            download_name=f"trained_model_{job_id}.zip",
            mimetype='application/zip'
        )

    except Exception as e:
        return jsonify({"status": "error", "message": f"Download failed: {str(e)}"}), 500


@app.route('/api/jobs', methods=['GET'])
def list_jobs():
    """List all training jobs"""
    jobs = {job_id: progress.to_dict() for job_id, progress in training_jobs.items()}
    return jsonify({"jobs": jobs})


@app.route('/')
def home():
    return jsonify({
        "message": "Welcome to the Enhanced LLM Fine-tuning API!",
| | "features": [ |
| | "Automatic question/answer column detection", |
| | "Configurable base model and dataset", |
| | "Local model download", |
| | "Progress tracking with ETA" |
| | ], |
| | "endpoints": { |
| | "POST /api/train": "Start training (accepts dataset_name and base_model in JSON)", |
| | "GET /api/status/<job_id>": "Get training status and detected columns", |
| | "GET /api/download/<job_id>": "Download trained model as zip", |
| | "GET /api/jobs": "List all jobs" |
| | }, |
| | "usage_example": { |
| | "start_training": { |
| | "method": "POST", |
| | "url": "/api/train", |
| | "body": { |
| | "dataset_name": "your-dataset-name", |
| | "base_model": "microsoft/DialoGPT-small" |
| | } |
| | } |
| | } |
| | }) |


@app.route('/health')
def health():
    return jsonify({"status": "healthy"})

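# Port 7860 is the convention used by Hugging Face Spaces; override it with
# the PORT environment variable.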
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)