Spaces:
Sleeping
Sleeping
| """ | |
| Inference pipeline for DistilBERT sentiment analysis | |
| File: infer.py (improved version) | |
| """ | |
| import torch | |
| import os | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Process-wide cache for the loaded model and tokenizer. Both start unset and
# are filled in by the first successful call to load_trained_model().
_model, _tokenizer = None, None
_loaded_model_path = None  # absolute path the cached _model/_tokenizer came from


def load_trained_model(model_path="./model"):
    """Load a saved model and tokenizer, caching them for the process.

    The pair is cached in the module-level ``_model`` / ``_tokenizer``
    globals, together with the resolved directory it was loaded from. A later
    call with the same directory returns the cached objects. A call with a
    different directory reloads from that directory and replaces the cache.

    Args:
        model_path: Local directory containing the saved model and tokenizer.

    Returns:
        Tuple of (model, tokenizer).

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    global _model, _tokenizer, _loaded_model_path

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No model found at {model_path}. Please train the model first.")

    # Key the cache on the absolute path, so "./model" and "model" share one
    # entry while a genuinely different directory forces a reload. Before this
    # change, the first-loaded model was returned for any path.
    resolved_path = os.path.abspath(model_path)
    if (
        _model is not None
        and _tokenizer is not None
        and _loaded_model_path == resolved_path
    ):
        return _model, _tokenizer

    print(f"Loading model from {model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    # Publish to the cache only after both loads succeed. Otherwise a failure
    # partway through could leave the cache holding a model and a tokenizer
    # from different directories.
    _model, _tokenizer, _loaded_model_path = model, tokenizer, resolved_path
    print("Model loaded successfully!")
    return _model, _tokenizer
def predict_sentiment(text, model, tokenizer, max_length=256):
    """Predict the sentiment of a single text.

    Args:
        text: Input text string.
        model: Loaded sequence-classification model. Its output is read via
            ``model(**inputs).logits``.
        tokenizer: Loaded tokenizer matching ``model``.
        max_length: Maximum number of tokens kept; longer inputs are truncated.

    Returns:
        Tuple of (label, confidence). ``label`` is "Positive" when class 1 has
        the highest probability and "Negative" otherwise. ``confidence`` is
        the softmax probability of the predicted class.
    """
    # Truncate to max_length but do not pad. A single sequence needs no
    # padding, and padding="max_length" made every call run the model over
    # max_length tokens however short the text was.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )

    # Put the input tensors on the same device as the model (e.g. a GPU);
    # otherwise the forward pass fails with a device-mismatch error.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    model.eval()  # disable dropout in case the caller left the model in train mode
    with torch.no_grad():
        outputs = model(**inputs)

    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(probabilities, dim=-1).item()
    confidence = probabilities[0][predicted_class].item()

    label = "Positive" if predicted_class == 1 else "Negative"
    return label, confidence
def predict(text, model_path="./model", max_length=256):
    """Predict the sentiment of ``text`` as a lowercase word.

    Convenience wrapper: loads (or reuses) the cached model and tokenizer via
    ``load_trained_model``, then delegates the forward pass to
    ``predict_sentiment``. Both entry points therefore tokenize and classify
    the same way.

    Args:
        text: Input text string.
        model_path: Path to the saved model directory.
        max_length: Maximum sequence length passed to the tokenizer.

    Returns:
        "positive" or "negative" on success. Failures are reported as strings
        instead of being raised:

        - a missing model directory (``FileNotFoundError``) returns a string
          starting with "Error: ";
        - any other ``Exception`` returns a string starting with
          "Prediction error: ".
    """
    try:
        model, tokenizer = load_trained_model(model_path)
        label, _confidence = predict_sentiment(
            text, model, tokenizer, max_length=max_length
        )
    except FileNotFoundError as e:
        return f"Error: {e}"
    except Exception as e:  # preserve the existing "return, don't raise" contract
        return f"Prediction error: {e}"
    # predict_sentiment returns "Positive"/"Negative"; this API returns the
    # lowercase form.
    return label.lower()