# Source: Hugging Face Space file "app.py" (1.82 kB), uploaded by Pulastya0.
# Commit 0f5b2b6 (verified): "Update app.py"
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# ASGI application object; the endpoints in this file register against it.
app = FastAPI(title="Routing Service - Space 2")

# Direct both Hugging Face cache variables at the same /data directory.
# NOTE(review): these are assigned after `transformers` has already been
# imported at the top of the file -- confirm the library still honours them
# when set this late.
os.environ.update(
    HF_HOME="/data/huggingface-cache",
    TRANSFORMERS_CACHE="/data/huggingface-cache",
)
# -------------------------------
# Request Model
# -------------------------------
class RoutingRequest(BaseModel):
    """Request body accepted by the ``/route`` endpoint."""

    # Free-form ticket text that should be routed to a department.
    text: str
# -------------------------------
# Load Routing Model (DeBERTa MNLI)
# -------------------------------
MODEL_NAME = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"

# Pass the cache directory explicitly instead of relying on the library to
# read the environment: this file assigns the cache variables only after
# `transformers` has been imported, which is too late for settings that are
# resolved at import time. `None` (variable unset) keeps the library default.
_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")

# Loaded once at import so every request reuses the same tokenizer/model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=_CACHE_DIR)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, cache_dir=_CACHE_DIR
)
# Candidate departments a ticket can be routed to; the order is significant
# because the routing endpoint indexes into this list.
DEPARTMENTS = [
    "Account",
    "Software",
    "Network",
    "Security",
    "Hardware",
    "Infrastructure",
    "Licensing",
    "Communication",
    "RemoteWork",
    "Training",
    "Performance",
]
# -------------------------------
# Routing Endpoint
# -------------------------------
@app.post("/route")
async def route_ticket(req: RoutingRequest):
    """Route a ticket to the department whose description it best entails.

    The loaded model is a natural-language-inference classifier, so its
    output classes are NLI labels, not departments. Each department is
    therefore turned into a hypothesis sentence, scored against the ticket
    text as the premise, and the department with the highest entailment
    logit is returned.

    Args:
        req: Request body carrying the ticket text.

    Returns:
        A dict of the form ``{"department": <name from DEPARTMENTS>}``.

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only; 500 if
            the model configuration exposes no entailment label.
    """
    text = req.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Locate the entailment class from the model config rather than
    # hard-coding its position in the output head.
    entail_idx = next(
        (
            idx
            for label, idx in model.config.label2id.items()
            if str(label).lower().startswith("entail")
        ),
        None,
    )
    if entail_idx is None:
        raise HTTPException(
            status_code=500, detail="Routing model has no entailment label"
        )

    # One (premise, hypothesis) pair per department, scored in one batch.
    hypotheses = [f"This ticket is about {dept}." for dept in DEPARTMENTS]
    inputs = tokenizer(
        [text] * len(hypotheses),
        hypotheses,
        return_tensors="pt",
        padding=True,
        # Trim only the ticket text so the hypothesis always survives.
        truncation="only_first",
    )

    # Inference only: skip building an autograd graph on every request.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Row i holds the NLI logits for DEPARTMENTS[i]; compare entailment only.
    best_idx = int(torch.argmax(logits[:, entail_idx]).item())
    return {"department": DEPARTMENTS[best_idx]}
# -------------------------------
# Health Check
# -------------------------------
@app.get("/health")
async def health():
    """Liveness probe: unconditionally reports the service status as ok."""
    payload = {"status": "ok"}
    return payload