Spaces:
Sleeping
Sleeping
| import os | |
| from pydantic import BaseModel | |
| from typing import List, Dict, Union | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
# Definition of Pydantic data models


# Request payload for a single prediction: one text to classify.
class ProblematicItem(BaseModel):
    text: str  # passed to the tokenizer by predict_single
# Request payload for batch scoring: the texts to score, in input order.
class ProblematicList(BaseModel):
    problematics: List[str]  # consumed in fixed-size chunks by predict_batch
# Response for a single prediction.
class PredictionResponse(BaseModel):
    predicted_class: str  # label name chosen for the arg-max class index
    score: float  # softmax probability of that class
# Generic multi-result container: each entry maps string keys to str or float values.
# No function in the visible code constructs this model.
# NOTE(review): under pydantic v1, Union members are tried left to right and the str
# validator accepts numbers, so float values here may be coerced to strings; pydantic
# v2's smart-mode union keeps them as float. Confirm the installed version.
class PredictionsResponse(BaseModel):
    results: List[Dict[str, Union[str, float]]]
# One batch-scoring result: the input text paired with its class-1 probability.
class BatchPredictionScoreItem(BaseModel):
    problematic: str  # the original input text, unchanged
    score: float  # softmax probability of class index 1
# Model configuration, read from environment variables at import time.
# Only MODEL_NAME is required; LABEL_0 / LABEL_1 stay None when unset.
MODEL_NAME = os.environ.get("MODEL_NAME")
LABEL_0 = os.environ.get("LABEL_0")
LABEL_1 = os.environ.get("LABEL_1")

if MODEL_NAME is None or MODEL_NAME == "":
    raise ValueError("Environment variable MODEL_NAME is not set.")
# Loading the model and tokenizer.
# Both handles start unset and are populated lazily by load_model().
tokenizer = model = None


def load_model():
    """Load the tokenizer and sequence-classification model named by MODEL_NAME.

    On success the module-level ``tokenizer`` and ``model`` globals are
    replaced. The tokenizer is loaded first, so if the model load fails the
    tokenizer global may already be set while ``model`` keeps its old value.

    Returns:
        True if both loads succeeded; False if either raised (the error is printed).
    """
    global tokenizer, model
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    except Exception as exc:
        print(f"Error loading model: {exc}")
        return False
    return True
def health_check():
    """Report service health, loading the model on first use.

    Previously this always reported "ok", even when the model could not be
    loaded. It now reports the failure.

    Returns:
        A dict with ``status`` ("ok" when the model and tokenizer are loaded,
        "unavailable" when loading failed) and ``model`` (the configured
        MODEL_NAME).
    """
    # Only reads the globals here; load_model() performs the assignment.
    if model is None or tokenizer is None:
        if not load_model():
            print("Model not available")
            return {"status": "unavailable", "model": MODEL_NAME}
    return {"status": "ok", "model": MODEL_NAME}
def predict_single(item: ProblematicItem):
    """Classify one text and return its label with the softmax confidence.

    The model is loaded on first use. The arg-max class index is mapped to a
    label through LABEL_0 (index 0) and LABEL_1 (index 1). When the matching
    environment label is unset, or the model predicts any other index, the
    model config's ``id2label`` entry is used instead. This avoids returning
    None for a valid prediction and avoids mislabelling classes >= 2 as LABEL_1.

    Args:
        item: Request carrying the text to classify.

    Returns:
        A PredictionResponse, or None if the model is unavailable or the
        prediction fails (the error is printed).
    """
    if model is None or tokenizer is None:
        if not load_model():
            print('Error loading the model.')
            # Without this return, the code would go on to call a None tokenizer.
            return None
    try:
        inputs = tokenizer(item.text, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # A single text was tokenized, so row 0 holds the class distribution.
        predicted_class = torch.argmax(probabilities, dim=1).item()
        confidence_score = probabilities[0][predicted_class].item()

        # Prefer the configured label names and fall back to the model's own labels.
        env_labels = {0: LABEL_0, 1: LABEL_1}
        predicted_label = env_labels.get(predicted_class)
        if predicted_label is None:
            predicted_label = model.config.id2label.get(predicted_class, str(predicted_class))

        return PredictionResponse(predicted_class=predicted_label, score=confidence_score)
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        return None
def predict_batch(items: ProblematicList):
    """Score each text by the softmax probability of class index 1.

    The model is loaded on first use. Texts are tokenized and run through the
    model in chunks of 8.

    Args:
        items: Request carrying the list of texts to score.

    Returns:
        A list of BatchPredictionScoreItem in input order. An empty list is
        returned when the input is empty, the model is unavailable, or any
        error occurs; in the error case the error is printed and partial
        results are discarded.
    """
    if model is None or tokenizer is None:
        if not load_model():
            print("Model not available")
            # Stop here instead of calling a None tokenizer below.
            return []
    batch_size = 8
    try:
        texts = items.problematics
        if not texts:
            return []
        results = []
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")
            with torch.no_grad():
                outputs = model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # Column 1 is the class of interest. This assumes at least two output
            # classes; otherwise an IndexError is raised and caught below.
            class_one_scores = probabilities[:, 1].tolist()
            results.extend(
                BatchPredictionScoreItem(problematic=text, score=score)
                for text, score in zip(batch_texts, class_one_scores)
            )
        return results
    except AttributeError as ae:
        print(f"AttributeError during prediction in predict_batch (likely wrong input type): {str(ae)}")
        return []
    except Exception as e:
        print(f"Generic error during prediction in predict_batch: {str(e)}")
        return []