|
|
import logging
import os
import shutil
from datetime import datetime
from uuid import uuid4

import pandas as pd
from fastapi import APIRouter, UploadFile, File, HTTPException

from model.train import train_model
from db.database import SessionLocal
from db.models import TrainingData, TrainingSession
|
|
|
# Scratch directory for uploaded training files: a "tmp" directory one level
# above this module's own directory. It is created when the module is
# imported, so request handlers can write into it immediately.
tmp_dir = os.path.join(os.path.dirname(__file__), os.pardir, "tmp")
os.makedirs(tmp_dir, exist_ok=True)

# Router on which the training endpoints in this module are registered.
router = APIRouter()
|
|
|
# Columns every uploaded training file must contain.
_REQUIRED_COLUMNS = ("prompt", "response", "question", "is_hallucination")


def _read_training_frame(path, filename):
    """Parse the saved upload at *path* into a DataFrame.

    The format is chosen from the client-supplied *filename*: a ``.csv``
    suffix (any letter case) is read as CSV; anything else is read as JSON.

    Raises:
        ValueError: pandas could not parse the file. pandas' ``ParserError``
            and ``EmptyDataError``, and Python's ``UnicodeDecodeError``, are
            all ``ValueError`` subclasses.
    """
    if filename.lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_json(path)


def _mark_session_failed(db, session, session_id):
    """Best-effort update of *session* to a failed state after an error.

    Does nothing when no DB session or training session exists yet. Any error
    raised while updating is logged and swallowed, so it cannot mask the
    failure that triggered this call.
    """
    if db is None or session is None:
        return
    try:
        # Discard any half-finished transaction (e.g. from a failed commit)
        # before writing the new status.
        db.rollback()
        # NOTE(review): assumes TrainingSession.status accepts "failed";
        # confirm against db.models.
        session.status = "failed"
        session.finished_at = datetime.utcnow()
        db.commit()
    except Exception:
        logging.getLogger(__name__).exception(
            "Could not mark training session %s as failed", session_id
        )


def _remove_temp_file(path):
    """Delete *path*; a missing file is ignored and other OS errors are logged."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError:
        # Cleanup runs in a ``finally``; raising here would hide the real
        # outcome of the request.
        logging.getLogger(__name__).warning(
            "Could not remove temporary file %s", path, exc_info=True
        )


@router.post("/train")
def train(file: UploadFile = File(...)):
    """Ingest an uploaded training file, store its rows, and train the model.

    Processing steps:

    1. The upload is saved under ``tmp_dir``.
    2. The file is parsed as CSV (``.csv`` suffix) or JSON (anything else).
       It must contain every column in ``_REQUIRED_COLUMNS``.
    3. Each row is stored as a ``TrainingData`` record.
    4. A ``TrainingSession`` is recorded as ``"running"``.
    5. ``train_model`` is called on the saved file, and the session is then
       marked ``"success"``.

    The DB session is always closed and the temporary file is always removed.

    Returns:
        ``{"status": "success", "training_samples": <row count>,
        "session_id": <uuid4 string>}``.

    Raises:
        HTTPException: 400 if the file cannot be parsed or lacks a required
            column; 500 for any other failure (detail is the error text).
    """
    session_id = str(uuid4())
    original_name = file.filename or ""
    # Keep only the final path component of the client-supplied name so a
    # crafted filename such as "../../x" cannot write outside tmp_dir.
    safe_name = os.path.basename(original_name)
    temp_path = os.path.join(tmp_dir, f"{session_id}_{safe_name}")

    db = None
    session = None
    try:
        with open(temp_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        try:
            df = _read_training_frame(temp_path, original_name)
        except ValueError as e:
            # A malformed upload is a client error, not a server error.
            raise HTTPException(
                status_code=400, detail=f"Could not parse training data: {e}"
            ) from e

        if not all(col in df.columns for col in _REQUIRED_COLUMNS):
            raise HTTPException(status_code=400, detail="Invalid columns in training data.")

        sample_count = len(df)

        db = SessionLocal()
        # to_dict(orient="records") avoids building a pandas Series per row
        # the way iterrows() does.
        db.add_all(
            [
                TrainingData(
                    prompt=record["prompt"],
                    response=record["response"],
                    question=record["question"],
                    is_hallucination=record["is_hallucination"],
                )
                for record in df.to_dict(orient="records")
            ]
        )
        db.commit()

        session = TrainingSession(
            id=session_id,
            status="running",
            training_samples=sample_count,
            started_at=datetime.utcnow(),
            finished_at=None,
        )
        db.add(session)
        db.commit()

        train_model(temp_path)

        session.status = "success"
        session.finished_at = datetime.utcnow()
        db.commit()

        return {"status": "success", "training_samples": sample_count, "session_id": session_id}
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400s above) must reach the client
        # unchanged instead of being re-wrapped as 500 below.
        raise
    except Exception as e:
        logging.getLogger(__name__).exception("Training session %s failed", session_id)
        _mark_session_failed(db, session, session_id)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if db is not None:
            db.close()
        _remove_temp_file(temp_path)
|
|
|
@router.get("/train/status")
def get_training_status():
    """Report the static capabilities of the training API.

    Every field is a constant. In particular, ``model_loaded`` is always
    ``True``: this endpoint neither inspects the model nor queries any
    training session.
    """
    formats = ["csv", "json"]
    columns = ["prompt", "response", "question", "is_hallucination"]
    return dict(
        status="ready",
        model_loaded=True,
        training_enabled=True,
        supported_formats=formats,
        required_columns=columns,
    )
|
|
|