KShoichi committed on
Commit
1d45303
·
verified ·
1 Parent(s): 01bde61

Upload app/api/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app/api/train.py +64 -0
app/api/train.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from fastapi import APIRouter, UploadFile, File, HTTPException
3
+ from model.train import train_model
4
+ from db.database import SessionLocal
5
+ from db.models import TrainingData, TrainingSession
6
+ import shutil, os
7
+ from uuid import uuid4
8
+ from datetime import datetime
9
+ import pandas as pd
10
+
11
# Scratch directory for uploaded datasets: a "tmp" folder next to the parent
# of this file's directory. It is created at import time so request handlers
# can write into it without checking for its existence.
tmp_dir = os.path.join(os.path.dirname(__file__), os.pardir, "tmp")
os.makedirs(tmp_dir, exist_ok=True)

# Router that collects the training endpoints defined in this module.
router = APIRouter()
15
+
16
@router.post("/train")
def train(file: UploadFile = File(...)):
    """Ingest an uploaded dataset, persist its rows, and run a training job.

    The upload is written to a scratch file under ``tmp_dir``, parsed as CSV
    when its name ends in ``.csv`` (case-insensitive) and as JSON otherwise,
    checked for the required columns, stored row by row as ``TrainingData``,
    and then handed to ``train_model``. A ``TrainingSession`` row tracks the
    run and is updated when training finishes.

    Args:
        file: Uploaded CSV or JSON file containing the columns ``prompt``,
            ``response``, ``question`` and ``is_hallucination``.

    Returns:
        A dict with ``status``, ``training_samples`` and ``session_id``.

    Raises:
        HTTPException: 400 when a required column is missing; 500 for any
            other failure, with the original error text as the detail.
    """
    session_id = str(uuid4())
    # The client controls the filename: keep only its final component so that
    # path separators or ".." cannot steer the write outside tmp_dir.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    temp_path = os.path.join(tmp_dir, f"{session_id}_{safe_name}")
    required_columns = ("prompt", "response", "question", "is_hallucination")

    db = None
    session = None
    try:
        # Save uploaded file to the scratch directory
        with open(temp_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Validate and load data
        if safe_name.lower().endswith(".csv"):
            df = pd.read_csv(temp_path)
        else:
            df = pd.read_json(temp_path)
        if not all(col in df.columns for col in required_columns):
            raise HTTPException(status_code=400, detail="Invalid columns in training data.")

        # Store training data
        db = SessionLocal()
        for _, row in df.iterrows():
            db.add(TrainingData(
                prompt=row["prompt"],
                response=row["response"],
                question=row["question"],
                is_hallucination=row["is_hallucination"],
            ))
        db.commit()

        # Start training
        session = TrainingSession(
            id=session_id,
            status="running",
            training_samples=len(df),
            started_at=datetime.utcnow(),
            finished_at=None,
        )
        db.add(session)
        db.commit()

        train_model(temp_path)

        session.status = "success"
        session.finished_at = datetime.utcnow()
        db.commit()
        return {"status": "success", "training_samples": len(df), "session_id": session_id}
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500.
        raise
    except Exception as e:
        if db is not None:
            # Best effort: discard the broken transaction and stop the session
            # row from staying "running" forever. A failure here must not mask
            # the original error, so it is only rolled back.
            # NOTE(review): "failed" is a new status value - confirm that the
            # TrainingSession.status column accepts it.
            try:
                db.rollback()
                if session is not None:
                    session.status = "failed"
                    session.finished_at = datetime.utcnow()
                    db.commit()
            except Exception:
                db.rollback()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release the DB session and the scratch file on every path.
        if db is not None:
            db.close()
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            # The upload was never written to disk; nothing to clean up.
            pass
54
+
55
@router.get("/train/status")
def get_training_status():
    """Report the fixed capabilities of the training endpoint.

    Returns:
        A dict with the constant readiness flags, the accepted upload
        formats, and the column names every dataset must contain.
    """
    accepted_formats = ["csv", "json"]
    expected_columns = ["prompt", "response", "question", "is_hallucination"]
    return dict(
        status="ready",
        model_loaded=True,
        training_enabled=True,
        supported_formats=accepted_formats,
        required_columns=expected_columns,
    )