"""FastAPI service that classifies check-in text with a Hugging Face model.

POST /predict with {"text": "..."} to get a label (Bad / Mediocre / Good)
and the per-class probabilities.
"""

import os

# 🧱 Set all possible cache directories to writable locations BEFORE
# importing transformers: the library reads HF_HOME / TRANSFORMERS_CACHE
# when it initializes, so assigning them after the import has no effect.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
os.environ["XDG_CACHE_HOME"] = "/tmp/huggingface"  # prevents /.cache access

# Ensure directory exists before the library tries to write into it.
os.makedirs("/tmp/huggingface/transformers", exist_ok=True)

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Initialize FastAPI
app = FastAPI(title="Check-ins Classifier API", version="1.0")

# Load model and tokenizer once at startup (downloads on first run).
MODEL_NAME = "mjpsm/check-ins-classifier"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()  # inference mode: disables dropout / training-only behavior

# Label mapping: class index -> human-readable label.
id2label = {
    0: "Bad",
    1: "Mediocre",
    2: "Good",
}


# Input schema for POST /predict.
class InputText(BaseModel):
    text: str


@app.post("/predict")
async def predict(data: InputText):
    """Classify the submitted text and return the label plus probabilities.

    Returns a dict with the echoed input, the predicted label string, the
    numeric label id, and the full softmax probability distribution.
    """
    inputs = tokenizer(data.text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # no gradients needed for inference
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_label_id = torch.argmax(probs, dim=-1).item()
    return {
        "input_text": data.text,
        "predicted_label": id2label[predicted_label_id],
        "label_id": predicted_label_id,
        "probabilities": probs.tolist(),
    }


@app.get("/")
async def home():
    """Landing endpoint describing how to use the API."""
    return {"message": "Welcome to the Check-ins Classifier API. Use POST /predict to classify text."}