import os
from typing import List, Dict, Union

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# FastAPI application. NOTE: the original listing omitted the app wiring;
# the route paths used in the decorators below are assumed.
app = FastAPI()

# Pydantic request/response models
class ProblematicItem(BaseModel):
    text: str

class ProblematicList(BaseModel):
    problematics: List[str]

class PredictionResponse(BaseModel):
    predicted_class: str
    score: float

class PredictionsResponse(BaseModel):
    results: List[Dict[str, Union[str, float]]]

# Model configuration from environment variables
MODEL_NAME = os.getenv("MODEL_NAME", "your-account/your-model")
LABEL_0 = os.getenv("LABEL_0", "Class A")
LABEL_1 = os.getenv("LABEL_1", "Class B")

# Model and tokenizer, loaded lazily on first use
tokenizer = None
model = None

def load_model():
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

@app.get("/health")
def health_check():
    global model, tokenizer
    if model is None or tokenizer is None:
        success = load_model()
        if not success:
            raise HTTPException(status_code=503, detail="Model not available")
    return {"status": "ok", "model": MODEL_NAME}

@app.post("/predict", response_model=PredictionResponse)
def predict_single(item: ProblematicItem):
    global model, tokenizer
    if model is None or tokenizer is None:
        success = load_model()
        if not success:
            raise HTTPException(status_code=503, detail="Model not available")
    try:
        # Tokenization
        inputs = tokenizer(item.text, padding=True, truncation=True, return_tensors="pt")
        # Prediction
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence_score = probabilities[0][predicted_class].item()
        # Map the numeric class to its configured label
        predicted_label = LABEL_0 if predicted_class == 0 else LABEL_1
        return PredictionResponse(predicted_class=predicted_label, score=confidence_score)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")

@app.post("/predict_batch", response_model=PredictionsResponse)
def predict_batch(items: ProblematicList):
    global model, tokenizer
    if model is None or tokenizer is None:
        success = load_model()
        if not success:
            raise HTTPException(status_code=503, detail="Model not available")
    try:
        results = []
        # Process the texts in small batches to bound memory usage
        batch_size = 8
        for i in range(0, len(items.problematics), batch_size):
            batch_texts = items.problematics[i:i+batch_size]
            # Tokenization
            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")
            # Prediction
            with torch.no_grad():
                outputs = model(**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
                predicted_classes = torch.argmax(probabilities, dim=1).tolist()
                confidence_scores = [probabilities[j][predicted_classes[j]].item() for j in range(len(predicted_classes))]
            # Convert numeric predictions into labels
            for j, (pred_class, score) in enumerate(zip(predicted_classes, confidence_scores)):
                predicted_label = LABEL_0 if pred_class == 0 else LABEL_1
                results.append({
                    "text": batch_texts[j],
                    "class": predicted_label,
                    "score": score
                })
        return PredictionsResponse(results=results)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")