# GokulRajaR's picture
# Update server.py
# 43852a9 verified
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import litserve as ls
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
import os
from dotenv import load_dotenv
load_dotenv()
class BiasAPI(ls.LitAPI):
    """LitServe API serving the ``valurank/distilroberta-bias`` classifier.

    The request payload is handed to the model unchanged, so callers are
    expected to send what ``predict_bias`` accepts: a single string or a
    list of strings. The response is a list with one
    ``{"label": ..., "score": ...}`` dict per input text.
    """

    def setup(self, device):
        """Load the tokenizer and model and move the model to ``device``.

        Args:
            device: Device identifier supplied by the server. When it is
                falsy, CUDA is used if available and the CPU otherwise.
        """
        model_name = "valurank/distilroberta-bias"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Inference only: put the model in evaluation mode.
        self.model.eval()

        if not device:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model = self.model.to(self.device)

    def predict_bias(self, texts, return_logits=False):
        """Predict a bias label for a string or a list of strings.

        Args:
            texts: One text, or a list/tuple of texts. Each text is
                truncated to at most 512 tokens.
            return_logits: When true, each result also carries the raw
                per-class logits under the ``"logits"`` key.

        Returns:
            A list with one dict per input text, holding the predicted
            ``"label"`` (looked up in the model's ``id2label`` mapping) and
            its softmax ``"score"``. An empty list/tuple yields ``[]``.
        """
        if isinstance(texts, str):
            texts = [texts]
        if isinstance(texts, (list, tuple)) and not texts:
            # Nothing to classify: do not hand an empty batch to the
            # tokenizer and model.
            return []

        enc = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].to(self.device)
        attention_mask = enc["attention_mask"].to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1)

        id2label = self.model.config.id2label
        # One argmax over the whole batch instead of one call per row.
        label_ids = torch.argmax(probs, dim=-1).tolist()

        results = []
        for row, label_id in enumerate(label_ids):
            res = {
                "label": id2label[label_id],
                "score": probs[row, label_id].item(),
            }
            if return_logits:
                # Tensor.tolist() already yields plain Python floats.
                res["logits"] = logits[row].tolist()
            results.append(res)
        return results

    def decode_request(self, request):
        """Return the request payload unchanged."""
        return request

    def predict(self, query):
        """Classify ``query`` by delegating to ``predict_bias``."""
        return self.predict_bias(query)

    def encode_response(self, output):
        """Return the prediction output unchanged."""
        return output

    def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Reject requests whose bearer token does not match ``auth_token``.

        The expected token is read from the ``auth_token`` environment
        variable. If that variable is unset or empty, every request is
        rejected.

        Args:
            auth: Credentials parsed from the ``Authorization`` header.

        Raises:
            HTTPException: 401 when the scheme is not Bearer or the token
                does not match.
        """
        # Standard library; imported locally because only this method uses it.
        import hmac

        expected = os.getenv("auth_token")
        supplied = auth.credentials or ""
        # Constant-time comparison so response timing does not leak how much
        # of a guessed token was correct. Encoding to bytes lets
        # compare_digest accept non-ASCII input as well.
        token_ok = bool(expected) and hmac.compare_digest(
            supplied.encode("utf-8"), expected.encode("utf-8")
        )
        # HTTP authentication scheme names are case-insensitive.
        if auth.scheme.lower() != "bearer" or not token_ok:
            raise HTTPException(status_code=401, detail="Bad token")
if __name__ == "__main__":
    # Script entry point: wrap the bias API in a LitServe server configured
    # for CPU and start serving on port 7860.
    bias_server = ls.LitServer(BiasAPI(), accelerator="cpu", devices="cpu")
    bias_server.run(port=7860)