Auto-train / app.py
Toilatop1sever's picture
Upload 3 files
3ccdfc1 verified
Raw
History Blame Contribute Delete
5.84 kB
"""
AutoTrain Mini — minimal FastAPI service for HuggingFace Spaces (CPU basic).
Endpoints:
POST /train { rows:[{text,label}], epochs, base_model }
GET /status
POST /infer { text }
CORS is wide open so the Lovable web app can call directly.
Training runs in a background thread (one job at a time).
"""
import os
import json
import threading
import traceback
from typing import List, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from torch.utils.data import Dataset
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
Trainer,
TrainingArguments,
TrainerCallback,
)
MODEL_DIR = "/tmp/model"
os.makedirs(MODEL_DIR, exist_ok=True)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
STATE = {
"state": "idle", # idle | running | done | error
"progress": 0,
"epoch": 0,
"loss": None,
"accuracy": None,
"message": "",
"logs": [],
"labels": [],
"model_id": None,
}
LOCK = threading.Lock()
_pipeline = None # cached inference pipeline
def log(msg: str):
    """Print ``msg`` and record it in ``STATE["logs"]``.

    Only the newest 200 messages are retained; once the list grows past that,
    it is replaced by a fresh list holding just the last 200 entries.
    """
    print(msg, flush=True)
    history = STATE["logs"]
    history.append(msg)
    if len(history) > 200:
        STATE["logs"] = history[-200:]
class Row(BaseModel):
    """One labelled training example in a POST /train request."""

    text: str  # input text the classifier is trained on
    label: str  # class name; the distinct values across all rows form the label set
class TrainReq(BaseModel):
    """Request body for POST /train."""

    rows: List[Row]  # training examples; all rows are used for training (no held-out split)
    epochs: int = 3  # forwarded to TrainingArguments.num_train_epochs
    base_model: str = "distilbert-base-uncased"  # model id/path for AutoTokenizer / AutoModelForSequenceClassification
class InferReq(BaseModel):
    """Request body for POST /infer."""

    text: str  # text to classify with the most recently trained model
class TextDataset(Dataset):
    """Map-style dataset over pre-tokenized columns plus integer labels.

    ``encodings`` maps a feature name (e.g. ``input_ids``) to a per-row list;
    ``labels`` holds one class id per row. Each item is a dict of tensors with
    an extra ``labels`` entry, the shape the HF ``Trainer`` consumes.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        sample = {}
        for feature, column in self.encodings.items():
            sample[feature] = torch.tensor(column[i])
        sample["labels"] = torch.tensor(self.labels[i])
        return sample
class ProgressCB(TrainerCallback):
    """Trainer callback that mirrors logged metrics into the global STATE."""

    def on_log(self, args, state, control, logs=None, **kw):
        # Nothing to record for a missing or empty log payload.
        if not logs:
            return
        # Copy the latest loss and fractional epoch, when present, as floats.
        for key in ("loss", "epoch"):
            if key in logs:
                STATE[key] = float(logs[key])
        # Guard against a falsy max_steps before dividing by it.
        if state.max_steps:
            STATE["progress"] = int(100 * state.global_step / state.max_steps)
        log(f"step={state.global_step} {logs}")
def do_train(req: TrainReq):
    """Fine-tune ``req.base_model`` on ``req.rows`` and save the result to MODEL_DIR.

    Intended to run on the background thread started by POST /train. Progress
    is published through the module-level ``STATE`` dict (read by GET /status):
    ``state`` goes from "running" to "done" on success, or to "error" if any
    step raises. In the error case ``message`` holds the exception text and the
    full traceback is appended to the logs. Exceptions are never re-raised,
    because a worker thread has no caller to receive them.

    Steps:
    1. Map string labels to contiguous ids in sorted order.
    2. Tokenize the texts.
    3. Train with the HF ``Trainer``.
    4. Report accuracy on the *training* rows (there is no held-out split).
    5. Save model, tokenizer and ``labels.json``.
    6. Drop the cached inference pipeline so the next /infer loads the new weights.
    """
    global _pipeline
    try:
        # Reset every per-job metric, including epoch, so /status never mixes
        # numbers from a previous run with this one.
        STATE.update(
            state="running",
            progress=0,
            epoch=0,
            message="Preparing data",
            logs=[],
            loss=None,
            accuracy=None,
        )
        log(f"Rows: {len(req.rows)} base_model={req.base_model} epochs={req.epochs}")

        texts = [r.text for r in req.rows]
        raw_labels = [r.label for r in req.rows]
        # Sorting makes the label -> id mapping deterministic across runs.
        label_list = sorted(set(raw_labels))
        if len(label_list) < 2:
            # 0 labels cannot build a classification head, and HF treats
            # num_labels == 1 as regression, which fails on integer labels.
            raise ValueError(
                f"Need at least 2 distinct labels to train a classifier, got {len(label_list)}"
            )
        label2id = {lab: idx for idx, lab in enumerate(label_list)}
        id2label = {idx: lab for lab, idx in label2id.items()}
        STATE["labels"] = label_list
        log(f"Labels ({len(label_list)}): {label_list}")
        y = [label2id[lab] for lab in raw_labels]

        STATE["message"] = "Loading tokenizer/model"
        tok = AutoTokenizer.from_pretrained(req.base_model)
        model = AutoModelForSequenceClassification.from_pretrained(
            req.base_model,
            num_labels=len(label_list),
            id2label=id2label,
            label2id=label2id,
        )

        STATE["message"] = "Tokenizing"
        # All texts are tokenized in one call, so padding=True pads every row
        # to the longest (truncated) row in the whole dataset.
        enc = tok(texts, truncation=True, padding=True, max_length=128)
        ds = TextDataset(enc, y)

        args = TrainingArguments(
            output_dir="/tmp/out",
            num_train_epochs=req.epochs,
            per_device_train_batch_size=8,
            learning_rate=5e-5,
            logging_steps=5,
            save_strategy="no",  # only the final model is saved, below
            report_to=[],
            disable_tqdm=True,
        )
        trainer = Trainer(model=model, args=args, train_dataset=ds, callbacks=[ProgressCB()])
        STATE["message"] = "Training"
        trainer.train()

        # No validation split exists, so this is accuracy on the training rows.
        STATE["message"] = "Evaluating (train acc)"
        preds = trainer.predict(ds)
        pred_ids = preds.predictions.argmax(-1)
        acc = float((pred_ids == preds.label_ids).mean())
        STATE["accuracy"] = acc
        log(f"Train accuracy: {acc:.4f}")

        STATE["message"] = "Saving"
        model.save_pretrained(MODEL_DIR)
        tok.save_pretrained(MODEL_DIR)
        with open(os.path.join(MODEL_DIR, "labels.json"), "w", encoding="utf-8") as f:
            json.dump(label_list, f)
        _pipeline = None  # force /infer to reload the freshly saved model
        STATE.update(state="done", progress=100, message="Done", model_id="local")
        log("Training complete.")
    except Exception as e:
        # Top-level boundary of a worker thread: record the failure for
        # /status instead of letting it vanish with the thread.
        STATE.update(state="error", message=str(e))
        log("ERROR: " + traceback.format_exc())
@app.get("/")
def root():
return {"ok": True, "service": "autotrain-mini", "state": STATE["state"]}
@app.get("/status")
def status():
return STATE
@app.post("/train")
def train(req: TrainReq):
with LOCK:
if STATE["state"] == "running":
raise HTTPException(409, "A training job is already running")
t = threading.Thread(target=do_train, args=(req,), daemon=True)
t.start()
return {"ok": True, "message": "started"}
def get_pipeline():
    """Return the cached text-classification pipeline, loading it on first use.

    Returns None when MODEL_DIR holds no ``config.json``, i.e. no model has
    been saved there yet. The pipeline is built with ``top_k=None`` so every
    label's score is returned.
    """
    global _pipeline
    if _pipeline is None:
        config_path = os.path.join(MODEL_DIR, "config.json")
        if not os.path.exists(config_path):
            return None
        # Imported lazily so the transformers pipeline machinery is only
        # loaded when there is a model to serve.
        from transformers import pipeline
        _pipeline = pipeline("text-classification", model=MODEL_DIR, tokenizer=MODEL_DIR, top_k=None)
    return _pipeline
@app.post("/infer")
def infer(req: InferReq):
pipe = get_pipeline()
if pipe is None:
raise HTTPException(400, "No trained model yet. Call /train first.")
out = pipe(req.text)
return {"input": req.text, "predictions": out}