# KShoichi's picture
# Upload app/api/train.py with huggingface_hub
# 1d45303 verified
from fastapi import APIRouter, UploadFile, File, HTTPException
from model.train import train_model
from db.database import SessionLocal
from db.models import TrainingData, TrainingSession
import shutil, os
from uuid import uuid4
from datetime import datetime
import pandas as pd
# Scratch directory for uploaded training files: a "tmp" folder located one
# level above the directory that contains this module.
tmp_dir = os.path.join(os.path.dirname(__file__), "..", "tmp")
# Created at import time; exist_ok=True keeps repeated imports from failing
# when the directory is already there.
os.makedirs(tmp_dir, exist_ok=True)
# Router on which the endpoints defined in this module are registered.
router = APIRouter()
@router.post("/train")
def train(file: UploadFile = File(...)):
    """Ingest an uploaded dataset, persist its rows and run a training pass.

    The upload is written to ``tmp_dir``, parsed with pandas (CSV when the
    file name ends in ``.csv``, case-insensitively; JSON otherwise), checked
    for the required columns, stored row by row as ``TrainingData`` and then
    handed to ``train_model`` while a ``TrainingSession`` row tracks the run.
    The database session is closed and the temporary file removed on every
    exit path, successful or not.

    Args:
        file: Uploaded CSV or JSON file providing the columns ``prompt``,
            ``response``, ``question`` and ``is_hallucination``.

    Returns:
        A dict with the keys ``status``, ``training_samples`` and
        ``session_id``.

    Raises:
        HTTPException: 400 when a required column is missing; 500 for any
            other failure, with the underlying error text as the detail.
    """
    session_id = str(uuid4())

    # The file name is client-controlled. Keep only its final path component
    # (treating both "/" and "\\" as separators) so a name such as
    # "../../x.csv" cannot escape tmp_dir. A missing name becomes "".
    upload_name = os.path.basename((file.filename or "").replace("\\", "/"))
    temp_path = os.path.join(tmp_dir, f"{session_id}_{upload_name}")

    required_columns = ("prompt", "response", "question", "is_hallucination")
    db = None
    try:
        # Save uploaded file to the Windows-friendly tmp dir
        with open(temp_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Validate and load data
        if upload_name.lower().endswith(".csv"):
            df = pd.read_csv(temp_path)
        else:
            df = pd.read_json(temp_path)
        if not all(col in df.columns for col in required_columns):
            raise HTTPException(status_code=400, detail="Invalid columns in training data.")
        sample_count = len(df)

        # Store training data
        db = SessionLocal()
        for _, row in df.iterrows():
            db.add(TrainingData(
                prompt=row["prompt"],
                response=row["response"],
                question=row["question"],
                is_hallucination=row["is_hallucination"],
            ))
        db.commit()

        # Start training
        session = TrainingSession(
            id=session_id,
            status="running",
            training_samples=sample_count,
            started_at=datetime.utcnow(),
            finished_at=None,
        )
        db.add(session)
        db.commit()

        train_model(temp_path)

        session.status = "success"
        session.finished_at = datetime.utcnow()
        db.commit()
        return {"status": "success", "training_samples": sample_count, "session_id": session_id}
    except HTTPException:
        # Deliberate HTTP errors (the 400 above) must reach the client as-is
        # instead of being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        # NOTE(review): if train_model fails, the TrainingSession row stays at
        # "running"; no failure status value is visible here, so none is
        # written - confirm the intended value and record it.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release the DB session and the temp file on every path.
        try:
            if db is not None:
                db.close()
        finally:
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                # Nothing to clean up when the upload was never written.
                pass
@router.get("/train/status")
def get_training_status():
    """Get training system status"""
    # The payload is a fixed description; nothing is probed at call time.
    accepted_formats = ["csv", "json"]
    expected_columns = ["prompt", "response", "question", "is_hallucination"]
    return dict(
        status="ready",
        model_loaded=True,
        training_enabled=True,
        supported_formats=accepted_formats,
        required_columns=expected_columns,
    )