"""
AutoTrain Mini — minimal FastAPI service for HuggingFace Spaces (CPU basic).

Endpoints:
    POST /train   { rows: [{text, label}], epochs, base_model }
    GET  /status
    POST /infer   { text }

CORS is wide open so the Lovable web app can call the service directly.
Training runs in a background thread (one job at a time).
"""
import os
import json
import threading
import traceback
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import torch
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    TrainerCallback,
)

# Where the fine-tuned model, tokenizer and labels.json are written at the end
# of a training run.
MODEL_DIR = "/tmp/model"
os.makedirs(MODEL_DIR, exist_ok=True)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared job state: written by the training thread, returned by GET /status.
STATE = {
    "state": "idle",  # idle | running | done | error
    "progress": 0,  # percent of training steps completed (100 when done)
    "epoch": 0,  # fractional epoch from the most recent trainer log
    "loss": None,  # most recently logged training loss
    "accuracy": None,  # accuracy measured on the training rows after training
    "message": "",  # current phase, or the error text on failure
    "logs": [],  # newest log lines (bounded, see log())
    "labels": [],  # sorted label names of the current/last run
    "model_id": None,  # set to "local" once a model has been saved
}
LOCK = threading.Lock()  # guards the "is a job already running?" check in /train
_pipeline = None  # cached inference pipeline; reset to None after each training run

_MAX_LOG_LINES = 200  # how many log lines STATE["logs"] retains


def log(msg: str) -> None:
    """Print *msg* and append it to STATE["logs"], keeping only the newest lines."""
    print(msg, flush=True)
    STATE["logs"].append(msg)
    if len(STATE["logs"]) > _MAX_LOG_LINES:
        STATE["logs"] = STATE["logs"][-_MAX_LOG_LINES:]


class Row(BaseModel):
    """One labelled training example."""

    text: str
    label: str


class TrainReq(BaseModel):
    """Request body of POST /train."""

    rows: List[Row]
    epochs: int = 3
    base_model: str = "distilbert-base-uncased"  # name/path passed to from_pretrained


class InferReq(BaseModel):
    """Request body of POST /infer."""

    text: str


class TextDataset(Dataset):
    """Torch dataset over pre-tokenized encodings and integer label ids."""

    def __init__(self, encodings, labels):
        # encodings: tokenizer output, i.e. a mapping of field name -> one
        # list per example; labels: list of int class ids, same length.
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        item = {k: torch.tensor(v[i]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[i])
        return item


class ProgressCB(TrainerCallback):
    """Trainer callback that mirrors loss, epoch and step progress into STATE."""

    def on_log(self, args, state, control, logs=None, **kw):
        if not logs:
            return
        if "loss" in logs:
            STATE["loss"] = float(logs["loss"])
        if "epoch" in logs:
            STATE["epoch"] = float(logs["epoch"])
        # max_steps can be 0 before the trainer has computed it; avoid dividing by it.
        if state.max_steps:
            STATE["progress"] = int(100 * state.global_step / state.max_steps)
        log(f"step={state.global_step} {logs}")


def do_train(req: TrainReq) -> None:
    """Fine-tune ``req.base_model`` on ``req.rows`` and save the result to MODEL_DIR.

    Intended to run in the background thread started by POST /train.
    Progress, metrics and logs are published through STATE. Any exception,
    including the input validation below, is caught, logged with its
    traceback and reported as ``state="error"`` instead of propagating.
    """
    global _pipeline
    try:
        # Reset every per-run field, including "epoch", so /status does not
        # show the previous run's values before the first trainer log arrives.
        STATE.update(
            state="running",
            progress=0,
            epoch=0,
            message="Preparing data",
            logs=[],
            loss=None,
            accuracy=None,
        )
        log(f"Rows: {len(req.rows)} base_model={req.base_model} epochs={req.epochs}")

        if req.epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {req.epochs}")

        texts = [row.text for row in req.rows]
        raw_labels = [row.label for row in req.rows]
        label_list = sorted(set(raw_labels))
        # A classifier needs at least two classes. transformers computes a
        # regression (MSE) loss when num_labels == 1, which does not fit
        # these integer class ids, and num_labels == 0 (no rows) is invalid.
        if len(label_list) < 2:
            raise ValueError(
                "Need at least 2 distinct labels to train a classifier, "
                f"got {len(label_list)}"
            )
        label2id = {label: idx for idx, label in enumerate(label_list)}
        id2label = {idx: label for label, idx in label2id.items()}
        STATE["labels"] = label_list
        log(f"Labels ({len(label_list)}): {label_list}")
        y = [label2id[label] for label in raw_labels]

        STATE["message"] = "Loading tokenizer/model"
        tok = AutoTokenizer.from_pretrained(req.base_model)
        model = AutoModelForSequenceClassification.from_pretrained(
            req.base_model,
            num_labels=len(label_list),
            id2label=id2label,
            label2id=label2id,
        )

        STATE["message"] = "Tokenizing"
        # Pads every example to the longest one (capped at 128 tokens).
        enc = tok(texts, truncation=True, padding=True, max_length=128)
        ds = TextDataset(enc, y)

        args = TrainingArguments(
            output_dir="/tmp/out",
            num_train_epochs=req.epochs,
            per_device_train_batch_size=8,
            learning_rate=5e-5,
            logging_steps=5,
            save_strategy="no",  # only the final model is saved, below
            report_to=[],
            disable_tqdm=True,
        )
        trainer = Trainer(model=model, args=args, train_dataset=ds, callbacks=[ProgressCB()])
        STATE["message"] = "Training"
        trainer.train()

        # There is no held-out split, so this is accuracy on the training rows.
        STATE["message"] = "Evaluating (train acc)"
        preds = trainer.predict(ds)
        pred_ids = preds.predictions.argmax(-1)
        acc = float((pred_ids == preds.label_ids).mean())
        STATE["accuracy"] = acc
        log(f"Train accuracy: {acc:.4f}")

        STATE["message"] = "Saving"
        model.save_pretrained(MODEL_DIR)
        tok.save_pretrained(MODEL_DIR)
        with open(os.path.join(MODEL_DIR, "labels.json"), "w") as f:
            json.dump(label_list, f)
        _pipeline = None  # force the inference pipeline to reload the new model

        STATE.update(state="done", progress=100, message="Done", model_id="local")
        log("Training complete.")
    except Exception as e:
        # Top-level boundary of the worker thread: report instead of crashing it.
        STATE.update(state="error", message=str(e))
        log("ERROR: " + traceback.format_exc())
@app.get("/")
def root():
    """Health check: report that the service is up and the current job state."""
    return {"ok": True, "service": "autotrain-mini", "state": STATE["state"]}


@app.get("/status")
def status():
    """Return a snapshot of the training job state.

    The training thread keeps mutating STATE (and replacing/appending to its
    "logs" list) while this response is serialized, so return copies of the
    dict and its lists rather than the live objects.
    """
    snapshot = dict(STATE)
    snapshot["logs"] = list(STATE["logs"])
    snapshot["labels"] = list(STATE["labels"])
    return snapshot


@app.post("/train")
def train(req: TrainReq):
    """Validate the request and start a background training job.

    Raises:
        HTTPException(400): no rows, fewer than two distinct labels, or
            epochs < 1. These would otherwise only fail later inside the
            training thread.
        HTTPException(409): a training job is already running.
        HTTPException(503): the training thread could not be started.
    """
    if not req.rows:
        raise HTTPException(400, "rows must not be empty")
    if len({row.label for row in req.rows}) < 2:
        raise HTTPException(400, "Need at least 2 distinct labels to train a classifier")
    if req.epochs < 1:
        raise HTTPException(400, "epochs must be >= 1")

    with LOCK:
        if STATE["state"] == "running":
            raise HTTPException(409, "A training job is already running")
        # Claim the slot while still holding the lock. Sync endpoints run in a
        # threadpool, and do_train only marks the job "running" once its thread
        # is scheduled, so without this two concurrent requests could both
        # pass the check above and start two jobs.
        STATE["state"] = "running"
        STATE["message"] = "Queued"
        t = threading.Thread(target=do_train, args=(req,), daemon=True)
        try:
            t.start()
        except RuntimeError as e:
            # Release the slot so a later request can retry.
            STATE.update(state="error", message=f"Could not start training thread: {e}")
            raise HTTPException(503, "Could not start training thread") from e
    return {"ok": True, "message": "started"}


_PIPELINE_LOCK = threading.Lock()  # serializes lazy loading of the inference pipeline


def get_pipeline():
    """Return the cached text-classification pipeline, loading it on first use.

    Returns None when MODEL_DIR has no saved model (no config.json) yet.
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline
    with _PIPELINE_LOCK:
        # Re-check under the lock: without it, concurrent /infer requests
        # would each load their own copy of the model.
        if _pipeline is not None:
            return _pipeline
        if not os.path.exists(os.path.join(MODEL_DIR, "config.json")):
            return None
        from transformers import pipeline

        # NOTE(review): if a training run is saving into MODEL_DIR at this
        # moment, this load may read partially written files — confirm
        # whether saving to a temp dir and swapping is needed.
        _pipeline = pipeline(
            "text-classification", model=MODEL_DIR, tokenizer=MODEL_DIR, top_k=None
        )
        return _pipeline


@app.post("/infer")
def infer(req: InferReq):
    """Classify ``req.text`` with the most recently saved model.

    Raises:
        HTTPException(400): no trained model is available yet.
    """
    pipe = get_pipeline()
    if pipe is None:
        raise HTTPException(400, "No trained model yet. Call /train first.")
    out = pipe(req.text)  # top_k=None: scores for every label
    return {"input": req.text, "predictions": out}