Hugging Face Spaces status: Runtime error (the Space fails on startup or on first request).
from fastapi import FastAPI, HTTPException
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import torch

app = FastAPI()

# Single source of truth for the checkpoint name (was repeated three times).
MODEL_ID = "SrivarshiniGanesan/finetuned-stress-model"

# Prefer GPU when available; model weights and request tensors both live here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the config first so label metadata (config.id2label) is available
# to the prediction code below.
config = AutoConfig.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    config=config,
).to(device)
model.eval()  # defensive: from_pretrained already returns eval mode, but be explicit
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
def predict(text: str):
    """Return the model's probability that *text* expresses stress.

    Args:
        text: Raw input text to classify.

    Returns:
        dict with key ``"stress_probability"`` mapping to a float in [0, 1].

    Raises:
        HTTPException: 500 with the underlying error message if tokenization,
            inference, or label resolution fails.
    """
    # NOTE(review): no @app.post(...) decorator is visible, so this function
    # does not appear to be registered as a route — confirm how it is exposed.
    try:
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, padding=True
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)

        # Resolve which logit index corresponds to the "Stress" class.
        # BUG FIX: the original did list(values()).index("Stress"), which
        # (a) raises ValueError when the checkpoint stores the transformers
        # default labels {0: "LABEL_0", 1: "LABEL_1"} — making every request
        # fail with a 500 — and (b) used the *position in values()* rather
        # than the label id as the probs index. Match case-insensitively on
        # the id2label keys and fall back to index 1 (conventional positive
        # class) when no label is literally named "stress".
        class_labels = config.id2label if config.id2label else {0: "No Stress", 1: "Stress"}
        stress_idx = next(
            (int(idx) for idx, name in class_labels.items()
             if str(name).lower() == "stress"),
            1,
        )
        return {"stress_probability": probs[0, stress_idx].item()}
    except Exception as e:
        # Chain the original exception so the traceback keeps the root cause.
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        ) from e