Spaces:
Sleeping
Sleeping
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import litserve as ls | |
| from fastapi import Depends, HTTPException | |
| from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer | |
| import os | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
class BiasAPI(ls.LitAPI):
    """LitServe API that classifies text with ``valurank/distilroberta-bias``.

    The request payload is forwarded unchanged to :meth:`predict_bias`, so it
    must already be a string or a list of strings.
    """

    def setup(self, device):
        """Load the tokenizer and model and move the model to the target device.

        Args:
            device: Device identifier to run on. When falsy, CUDA is used if
                available, otherwise the CPU.
        """
        model_name = "valurank/distilroberta-bias"

        # Load tokenizer and PyTorch model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()  # evaluation mode: this API only runs inference

        # Use GPU if available
        self.device = torch.device(
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.model = self.model.to(self.device)

    def predict_bias(self, texts, return_logits=False):
        """Predict bias labels for a string or list of strings.

        Args:
            texts: A single string or a list of strings to classify. Inputs
                are truncated to 512 tokens.
            return_logits: When true, each result also carries the raw
                logits for that input under ``"logits"``.

        Returns:
            One dict per input, in input order, with ``"label"`` (taken from
            the model config's ``id2label``) and ``"score"`` (the softmax
            probability of that label). An empty list yields ``[]``.
        """
        if isinstance(texts, str):
            texts = [texts]
        # Nothing to classify: answer directly instead of handing the
        # tokenizer and model an empty batch.
        if isinstance(texts, (list, tuple)) and not texts:
            return []

        enc = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].to(self.device)
        attention_mask = enc["attention_mask"].to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
        probs = torch.softmax(logits, dim=-1)

        # One batched reduction gives the winning class and its probability
        # for every row, instead of an argmax/.item() round trip per input.
        scores, label_ids = probs.max(dim=-1)
        id2label = self.model.config.id2label

        results = []
        for row, (label_id, score) in enumerate(zip(label_ids.tolist(), scores.tolist())):
            res = {"label": id2label[label_id], "score": score}
            if return_logits:
                res["logits"] = logits[row].tolist()
            results.append(res)
        return results

    def decode_request(self, request):
        """Return the request payload unchanged."""
        return request

    def predict(self, query):
        """Classify the decoded request with :meth:`predict_bias`."""
        return self.predict_bias(query)

    def encode_response(self, output):
        """Return the prediction output unchanged."""
        return output

    def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Reject requests whose bearer token does not match ``auth_token``.

        The expected token is read from the ``auth_token`` environment
        variable. When that variable is unset, every request is rejected.

        Raises:
            HTTPException: 401 ``"Bad token"`` when the scheme is not
                ``Bearer`` or the token does not match.
        """
        import secrets  # local import: only this method needs it

        expected = os.getenv("auth_token")
        # Compare in constant time so response timing does not leak how much
        # of a guessed token is correct. Bytes are used because
        # compare_digest only accepts ASCII-only str.
        token_ok = (
            expected is not None
            and auth.scheme == "Bearer"
            and secrets.compare_digest(
                auth.credentials.encode("utf-8"), expected.encode("utf-8")
            )
        )
        if not token_ok:
            raise HTTPException(status_code=401, detail="Bad token")
if __name__ == "__main__":
    # Serve the bias classifier on CPU, listening on port 7860.
    ls.LitServer(BiasAPI(), devices="cpu", accelerator="cpu").run(port=7860)