"""Training endpoints: accept a labelled dataset, persist it, and retrain the model."""

import logging
import os
import shutil
from datetime import datetime
from uuid import uuid4

import pandas as pd
from fastapi import APIRouter, File, HTTPException, UploadFile

from db.database import SessionLocal
from db.models import TrainingData, TrainingSession
from model.train import train_model

logger = logging.getLogger(__name__)

# Columns every uploaded training file must provide.
REQUIRED_COLUMNS = ("prompt", "response", "question", "is_hallucination")

# Uploaded files are staged here (next to the package) while a run is in progress.
tmp_dir = os.path.join(os.path.dirname(__file__), "..", "tmp")
os.makedirs(tmp_dir, exist_ok=True)

router = APIRouter()


def _safe_filename(raw_name):
    """Return the final path component of a client-supplied filename.

    Both ``/`` and ``\\`` are treated as separators so the result cannot
    point outside ``tmp_dir`` regardless of the client's or server's OS.
    A missing filename yields an empty string.
    """
    return os.path.basename((raw_name or "").replace("\\", "/"))


def _load_training_frame(path, filename):
    """Parse the staged upload into a DataFrame and validate its columns.

    Files whose name ends in ``.csv`` (case-insensitive) are read as CSV;
    everything else is read as JSON.

    Raises:
        HTTPException: 400 if the file cannot be parsed or a required
            column is missing.
    """
    try:
        if filename.lower().endswith(".csv"):
            df = pd.read_csv(path)
        else:
            df = pd.read_json(path)
    except ValueError as exc:
        # pandas parser errors (and decode errors) derive from ValueError;
        # a malformed upload is the client's fault, not a server error.
        raise HTTPException(
            status_code=400,
            detail=f"Could not parse training file: {exc}",
        ) from exc

    if any(col not in df.columns for col in REQUIRED_COLUMNS):
        raise HTTPException(
            status_code=400, detail="Invalid columns in training data."
        )
    return df


def _mark_session_failed(db, session):
    """Best-effort: record that a training session did not complete.

    Any error raised while updating the row is logged and swallowed so it
    cannot mask the exception that caused the failure in the first place.
    """
    try:
        db.rollback()
        if session is not None:
            # NOTE(review): "failed" is a new status value next to
            # "running"/"success" - confirm consumers of
            # TrainingSession.status accept it.
            session.status = "failed"
            session.finished_at = datetime.utcnow()
            db.commit()
    except Exception:
        logger.exception("Could not mark training session as failed")


def _remove_quietly(path):
    """Delete *path*, logging (not raising) if it cannot be removed."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError:
        logger.warning("Could not remove temporary file %s", path, exc_info=True)


@router.post("/train")
def train(file: UploadFile = File(...)):
    """Store an uploaded dataset, then train the model on it.

    The upload is staged in ``tmp_dir``, parsed as CSV or JSON, validated
    against ``REQUIRED_COLUMNS``, written to the database as
    ``TrainingData`` rows, and passed to ``train_model``. A
    ``TrainingSession`` row tracks the run.

    Returns:
        A dict with ``status``, ``training_samples`` and ``session_id``.

    Raises:
        HTTPException: 400 for an unparseable file or missing columns;
            500 for any other failure.
    """
    session_id = str(uuid4())
    filename = _safe_filename(file.filename)
    temp_path = os.path.join(tmp_dir, f"{session_id}_{filename}")
    db = None
    session = None

    try:
        with open(temp_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        df = _load_training_frame(temp_path, filename)

        # Store training data
        db = SessionLocal()
        db.add_all(
            [
                TrainingData(
                    prompt=row["prompt"],
                    response=row["response"],
                    question=row["question"],
                    is_hallucination=row["is_hallucination"],
                )
                for _, row in df.iterrows()
            ]
        )
        db.commit()

        # Record the run before training so its progress is visible.
        session = TrainingSession(
            id=session_id,
            status="running",
            training_samples=len(df),
            started_at=datetime.utcnow(),
            finished_at=None,
        )
        db.add(session)
        db.commit()

        train_model(temp_path)

        session.status = "success"
        session.finished_at = datetime.utcnow()
        db.commit()

        return {
            "status": "success",
            "training_samples": len(df),
            "session_id": session_id,
        }
    except HTTPException:
        # Deliberate client errors (e.g. 400) must not be rewrapped as 500.
        if db is not None:
            _mark_session_failed(db, session)
        raise
    except Exception as e:
        if db is not None:
            _mark_session_failed(db, session)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always release the DB connection and the staged upload,
        # whether the run succeeded or not.
        if db is not None:
            db.close()
        _remove_quietly(temp_path)


@router.get("/train/status")
def get_training_status():
    """Get training system status.

    Returns a static description of the endpoint's capabilities: accepted
    file formats and the columns a training file must contain.
    """
    return {
        "status": "ready",
        "model_loaded": True,
        "training_enabled": True,
        "supported_formats": ["csv", "json"],
        "required_columns": list(REQUIRED_COLUMNS),
    }