"""
AutoTrain Mini — minimal FastAPI service for Hugging Face Spaces (CPU basic).

Endpoints:
    POST /train    {rows: [{text, label}], epochs, base_model}
    GET  /status
    POST /infer    {text}

CORS is wide open so the Lovable web app can call the service directly.
Training runs in a background thread (one job at a time).
"""
import json
import os
import threading
import traceback
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
# Directory where the fine-tuned model, tokenizer and labels.json are saved.
# Defaults to /tmp/model; set the MODEL_DIR environment variable to point it
# elsewhere (e.g. a persistent volume) without editing the code.
MODEL_DIR = os.environ.get("MODEL_DIR", "/tmp/model")
os.makedirs(MODEL_DIR, exist_ok=True)
app = FastAPI()

# Fully permissive CORS so the browser-based front end (the Lovable web app)
# can call this API directly from its own origin.
# NOTE(review): with no authentication visible in this service, any website
# can trigger training jobs — tighten allow_origins if that matters.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Shared job status: written by the training thread (do_train / ProgressCB)
# and returned as-is by status().
STATE = dict(
    state="idle",      # idle | running | done | error
    progress=0,        # percent of training steps completed (0-100)
    epoch=0,           # latest epoch value reported by the Trainer
    loss=None,         # latest training loss reported by the Trainer
    accuracy=None,     # accuracy on the training set after the last run
    message="",        # human-readable description of the current phase
    logs=[],           # recent log lines; log() keeps at most 200
    labels=[],         # sorted label names of the latest run
    model_id=None,     # set to "local" once a model has been saved
)

# Serializes the "is a job already running?" check in train().
LOCK = threading.Lock()

# Lazily built inference pipeline (see get_pipeline); cleared after a
# successful training run so the new model is picked up.
_pipeline = None
def log(msg: str):
    """Print ``msg`` and record it in STATE["logs"].

    Only the 200 most recent messages are retained, so the status payload
    stays bounded during long training runs.
    """
    print(msg, flush=True)
    buffer = STATE["logs"]
    buffer.append(msg)
    excess = len(buffer) - 200
    if excess > 0:
        # Drop the oldest entries, keeping exactly the newest 200.
        STATE["logs"] = buffer[excess:]
class Row(BaseModel):
    # One labelled training example.
    text: str
    label: str


class TrainReq(BaseModel):
    # Request body for POST /train.
    rows: List[Row]
    # Must be >= 1: zero/negative epochs would "finish" without training.
    # Invalid values are rejected during request validation (FastAPI answers 422).
    epochs: int = Field(3, ge=1)
    # Model id or path passed to from_pretrained() for tokenizer and model.
    base_model: str = "distilbert-base-uncased"


class InferReq(BaseModel):
    # Request body for POST /infer.
    text: str
class TextDataset(Dataset):
    """Map-style torch dataset over pre-tokenized texts and integer labels.

    Args:
        encodings: mapping of feature name -> per-example sequences (in this
            file, the tokenizer output built in do_train).
        labels: one integer class id per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # One example per label.
        return len(self.labels)

    def __getitem__(self, i):
        example = {}
        for feature, per_example in self.encodings.items():
            example[feature] = torch.tensor(per_example[i])
        # Stored under "labels", the key the Trainer/model uses for the loss.
        example["labels"] = torch.tensor(self.labels[i])
        return example
class ProgressCB(TrainerCallback):
    """Trainer callback that copies log events into STATE for status()."""

    def on_log(self, args, state, control, logs=None, **kw):
        if not logs:
            return
        # Record the latest loss and epoch when this event carries them.
        for key in ("loss", "epoch"):
            if key in logs:
                STATE[key] = float(logs[key])
        if state.max_steps:  # avoid dividing by zero before max_steps is known
            STATE["progress"] = int(100 * state.global_step / state.max_steps)
        log(f"step={state.global_step} {logs}")
def _encode_labels(raw_labels):
    """Map string labels to contiguous integer ids in sorted-label order.

    Returns:
        (label_list, label2id, id2label, ids) where ``ids`` has one id per input.

    Raises:
        ValueError: if fewer than two distinct labels are present. With
            num_labels == 1, transformers configures the model for regression,
            which cannot train on integer class ids.
    """
    label_list = sorted(set(raw_labels))
    if len(label_list) < 2:
        raise ValueError(
            f"Need at least 2 distinct labels to train a classifier, got {len(label_list)}"
        )
    label2id = {label: i for i, label in enumerate(label_list)}
    id2label = {i: label for label, i in label2id.items()}
    return label_list, label2id, id2label, [label2id[label] for label in raw_labels]


def do_train(req: TrainReq):
    """Fine-tune ``req.base_model`` on ``req.rows`` and save the result to MODEL_DIR.

    Runs in the background thread started by train(). Progress, metrics and
    log lines are published through STATE. Any exception is caught and
    reported as state="error" instead of being raised.
    """
    global _pipeline
    try:
        # Reset every per-run field, including epoch, so a new run does not
        # show the previous run's values.
        STATE.update(
            state="running",
            progress=0,
            epoch=0,
            message="Preparing data",
            logs=[],
            loss=None,
            accuracy=None,
        )
        log(f"Rows: {len(req.rows)} base_model={req.base_model} epochs={req.epochs}")
        if not req.rows:
            raise ValueError("No training rows provided")
        if req.epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {req.epochs}")

        texts = [r.text for r in req.rows]
        label_list, label2id, id2label, y = _encode_labels([r.label for r in req.rows])
        STATE["labels"] = label_list
        log(f"Labels ({len(label_list)}): {label_list}")

        STATE["message"] = "Loading tokenizer/model"
        # NOTE(review): base_model comes straight from the request; loading
        # weights from an arbitrary Hub repo may deserialize untrusted files
        # (pickle-based checkpoints). Consider an allow-list of model ids.
        tok = AutoTokenizer.from_pretrained(req.base_model)
        model = AutoModelForSequenceClassification.from_pretrained(
            req.base_model,
            num_labels=len(label_list),
            id2label=id2label,
            label2id=label2id,
        )

        STATE["message"] = "Tokenizing"
        # All texts are tokenized as one batch: pad to the longest, truncate at 128 tokens.
        enc = tok(texts, truncation=True, padding=True, max_length=128)
        ds = TextDataset(enc, y)

        args = TrainingArguments(
            output_dir="/tmp/out",
            num_train_epochs=req.epochs,
            per_device_train_batch_size=8,
            learning_rate=5e-5,
            logging_steps=5,
            save_strategy="no",
            report_to=[],
            disable_tqdm=True,
        )
        trainer = Trainer(model=model, args=args, train_dataset=ds, callbacks=[ProgressCB()])
        STATE["message"] = "Training"
        trainer.train()

        # Accuracy on the training data itself: a sanity check, not a held-out metric.
        STATE["message"] = "Evaluating (train acc)"
        preds = trainer.predict(ds)
        pred_ids = preds.predictions.argmax(-1)
        acc = float((pred_ids == preds.label_ids).mean())
        STATE["accuracy"] = acc
        log(f"Train accuracy: {acc:.4f}")

        STATE["message"] = "Saving"
        model.save_pretrained(MODEL_DIR)
        tok.save_pretrained(MODEL_DIR)
        with open(os.path.join(MODEL_DIR, "labels.json"), "w", encoding="utf-8") as f:
            json.dump(label_list, f)
        _pipeline = None  # drop the cached pipeline so inference reloads the new model
        STATE.update(state="done", progress=100, message="Done", model_id="local")
        log("Training complete.")
    except Exception as e:
        # Top-level boundary of a background thread: record the failure for
        # status() rather than letting the thread die silently.
        STATE.update(state="error", message=str(e) or type(e).__name__)
        log("ERROR: " + traceback.format_exc())
@app.get("/")
def root():
    """Liveness/identity check: service name plus the current job state."""
    return {"ok": True, "service": "autotrain-mini", "state": STATE["state"]}
@app.get("/status")
def status():
    """Return the full job-status dict (state, progress, metrics, recent logs)."""
    return STATE
@app.post("/train")
def train(req: TrainReq):
    """Start a background fine-tuning job for ``req`` and return immediately.

    Only one job may run at a time; progress is reported through status().

    Raises:
        HTTPException(409): if a training job is already running.
    """
    with LOCK:
        if STATE["state"] == "running":
            raise HTTPException(409, "A training job is already running")
        # Claim the single job slot while still holding the lock. do_train only
        # marks the job "running" once its thread gets scheduled, so without
        # this a second request arriving in that gap would also pass the check.
        STATE.update(state="running", progress=0, message="Queued")
        worker = threading.Thread(target=do_train, args=(req,), daemon=True)
        try:
            worker.start()
        except RuntimeError as e:
            # The thread could not be spawned: release the slot so later
            # requests are not rejected with 409 forever.
            STATE.update(state="error", message=str(e))
            raise
    return {"ok": True, "message": "started"}
def get_pipeline():
    """Return the cached text-classification pipeline, building it on first use.

    Returns None when MODEL_DIR has no saved model yet (no config.json).
    The pipeline is created with top_k=None, so it scores every label.
    """
    global _pipeline
    if _pipeline is None:
        if not os.path.exists(os.path.join(MODEL_DIR, "config.json")):
            return None
        # Deferred import: only needed once a trained model exists.
        from transformers import pipeline

        _pipeline = pipeline(
            "text-classification", model=MODEL_DIR, tokenizer=MODEL_DIR, top_k=None
        )
    return _pipeline
@app.post("/infer")
def infer(req: InferReq):
    """Classify ``req.text`` with the most recently trained model.

    Returns the input echoed back together with the pipeline output
    (get_pipeline builds it with top_k=None, i.e. scores for all labels).

    Raises:
        HTTPException(400): if no trained model has been saved yet.
    """
    pipe = get_pipeline()
    if pipe is None:
        raise HTTPException(400, "No trained model yet. Call /train first.")
    return {"input": req.text, "predictions": pipe(req.text)}