# check-in-API / app.py
# Author: mjpsm
# Last commit: "Update app.py" (e0d585f, verified)
import os

# 🧱 Point every cache location Hugging Face (and any XDG-aware library) may
# write to at a writable directory under /tmp.
# These MUST be assigned before `transformers` is imported: transformers and
# huggingface_hub resolve their cache paths from the environment when they are
# first imported, so values assigned afterwards are ignored.
os.environ["HF_HOME"] = "/tmp/huggingface"
# TRANSFORMERS_CACHE is deprecated in recent transformers releases in favour of
# HF_HOME; kept so older versions also honour the /tmp location.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface/datasets"
os.environ["XDG_CACHE_HOME"] = "/tmp/huggingface"  # prevents /.cache access
# Ensure the transformers cache directory exists.
os.makedirs("/tmp/huggingface/transformers", exist_ok=True)

# Third-party imports are intentionally placed AFTER the cache setup above.
import torch  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers import AutoModelForSequenceClassification, AutoTokenizer  # noqa: E402

# Initialize FastAPI
app = FastAPI(title="Check-ins Classifier API", version="1.0")

# Load model and tokenizer once, at import time; they are module-level globals
# reused by every request handler.
MODEL_NAME = "mjpsm/check-ins-classifier"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()  # evaluation mode (e.g. dropout disabled) for inference
# Label mapping: class index (position of the model's argmax) -> readable name.
# NOTE(review): assumes the model was trained with 0=Bad, 1=Mediocre, 2=Good;
# confirm against the model's config.
_LABEL_NAMES = ("Bad", "Mediocre", "Good")
id2label = dict(enumerate(_LABEL_NAMES))
# Input schema: JSON request body accepted by POST /predict.
# (Comments rather than a docstring: pydantic/FastAPI would publish a class
# docstring as the schema description in the generated OpenAPI document.)
class InputText(BaseModel):
    # The check-in text to classify; required, must be a string.
    text: str
import threading  # noqa: E402

# The tokenizer and model are shared module-level objects, and inference runs in
# worker threads (see `predict`). Serialize it: Hugging Face *fast* tokenizers
# have been reported to raise "Already borrowed" when called from several
# threads at once.
_inference_lock = threading.Lock()


def _classify(text: str) -> dict:
    """Run the classifier on *text* and build the JSON response payload.

    This is blocking, CPU-bound work; call it off the event loop.
    """
    with _inference_lock:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            outputs = model(**inputs)
    # Softmax over the last (label) axis turns logits into class probabilities.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # A single text is tokenized, so argmax yields exactly one element.
    predicted_label_id = torch.argmax(probs, dim=-1).item()
    return {
        "input_text": text,
        "predicted_label": id2label[predicted_label_id],
        "label_id": predicted_label_id,
        # Nested list with one inner list per input (here exactly one); this
        # shape is part of the existing response contract, so it is not flattened.
        "probabilities": probs.tolist(),
    }


@app.post("/predict")
async def predict(data: InputText):
    """Classify the submitted check-in text."""
    # Tokenization and the forward pass block the CPU. Running them directly in
    # this coroutine would stall the event loop (and every other request, e.g.
    # GET /). Awaiting them in the thread pool keeps the server responsive.
    from fastapi.concurrency import run_in_threadpool

    return await run_in_threadpool(_classify, data.text)
@app.get("/")
async def home():
    # Landing endpoint: point callers at the classification route.
    # (No docstring on purpose: FastAPI would expose it in the OpenAPI docs.)
    greeting = "Welcome to the Check-ins Classifier API. Use POST /predict to classify text."
    payload = {"message": greeting}
    return payload