Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import RobertaTokenizer, RobertaForSequenceClassification | |
# Standard Categories for our ABSA Model
# (the fixed set of aspect categories the classifier was fine-tuned on; the
# sentence-pair input in predict() pairs the review with each of these)
ASPECT_CATEGORIES = ['food', 'service', 'ambiance', 'price', 'anecdotes/miscellaneous']
# Mapping from the model's output class index to a sentiment label.
# NOTE(review): 'conflict' presumably denotes mixed positive/negative sentiment
# for the same aspect — confirm against the training label scheme.
LABELS = {0: 'positive', 1: 'negative', 2: 'neutral', 3: 'conflict'}
class ABSAPredictor:
    """Aspect-Based Sentiment Analysis predictor backed by a fine-tuned RoBERTa model.

    For each aspect category the review text is paired with the aspect name as a
    sentence-pair classification input. If the model directory cannot be loaded,
    the predictor degrades to a keyword-based MOCK mode instead of raising, so
    downstream code can run before the trained model is available.
    """

    def __init__(self, model_path: str = 'models/absa-roberta-final'):
        """
        Initializes the ABSA predictor by loading the fine-tuned RoBERTa model and tokenizer.

        Args:
            model_path (str): Directory containing the fine-tuned model/tokenizer files.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        try:
            self.tokenizer = RobertaTokenizer.from_pretrained(model_path)
            self.model = RobertaForSequenceClassification.from_pretrained(model_path)
            self.model = self.model.to(self.device)
            self.model.eval()  # disable dropout etc. for deterministic inference
            self.is_loaded = True
            print(f"Model successfully loaded from {model_path} onto {self.device}")
        except Exception as e:
            # Deliberate broad catch: ANY load failure (missing files, bad
            # checkpoint, missing dependency) switches to mock mode rather
            # than crashing the caller.
            print(f"Warning: Could not load model from {model_path}. Error: {e}")
            print("The predictor will run in MOCK mode until you place the trained model in the directory.")
            self.is_loaded = False

    def predict(self, review_text: str, confidence_threshold: float = 0.6) -> dict:
        """
        Predicts sentiments for all aspect categories in a single review.

        Args:
            review_text (str): The raw text of the review.
            confidence_threshold (float): Only return aspects where model confidence > threshold.

        Returns:
            dict: A dictionary mapping aspect -> {'sentiment': str, 'confidence': float}
        """
        # Fallback Mock Mode if model isn't built yet
        if not self.is_loaded:
            return self._mock_predict(review_text)

        results = {}
        with torch.no_grad():
            # Batch all (review, aspect) pairs into ONE forward pass instead of
            # looping with one model call per category — same logits, 5x fewer
            # model invocations.
            inputs = self.tokenizer(
                [review_text] * len(ASPECT_CATEGORIES),
                ASPECT_CATEGORIES,
                truncation=True,
                padding='max_length',
                max_length=128,
                return_tensors='pt'
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            logits = self.model(**inputs).logits
            probs = torch.softmax(logits, dim=1)
            # Explicit dim=1 (per-row max) instead of argmax over a flattened
            # tensor, which only worked because the original batch size was 1.
            confidences, pred_idxs = torch.max(probs, dim=1)
            for aspect, idx, conf in zip(
                ASPECT_CATEGORIES, pred_idxs.tolist(), confidences.tolist()
            ):
                # We only say an aspect is "mentioned" if the confidence of
                # *any* sentiment is high enough.
                if conf >= confidence_threshold:
                    results[aspect] = {
                        'sentiment': LABELS[idx],
                        'confidence': round(conf, 3)
                    }
        return results

    def _mock_predict(self, review_text: str) -> dict:
        """Keyword heuristic used before the trained model is available.

        Mirrors the real predict() contract: aspect -> {'sentiment', 'confidence'}.
        """
        results = {}
        # Simple mock logic for demonstration before actual model is downloaded
        lower_text = review_text.lower()
        if 'pizza' in lower_text or 'food' in lower_text or 'delicious' in lower_text:
            results['food'] = {'sentiment': 'positive', 'confidence': 0.95}
        if 'waiter' in lower_text or 'slow' in lower_text or 'rude' in lower_text:
            results['service'] = {'sentiment': 'negative', 'confidence': 0.88}
        if 'expensive' in lower_text or 'cheap' in lower_text:
            results['price'] = {'sentiment': 'negative' if 'expensive' in lower_text else 'positive', 'confidence': 0.82}
        return results
| if __name__ == "__main__": | |
| import argparse | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--text', type=str, required=True, help="Review text to analyze") | |
| args = parser.parse_args() | |
| predictor = ABSAPredictor() | |
| results = predictor.predict(args.text) | |
| print(f"\nReview: '{args.text}'") | |
| print("Detected Aspects & Sentiments:") | |
| for aspect, data in results.items(): | |
| print(f" - {aspect.ljust(15)}: {data['sentiment'].ljust(10)} (conf: {data['confidence']})") | |
| if not results: | |
| print(" - No aspects detected above the confidence threshold.") | |