Spaces:
Build error
Build error
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from utils.logger import setup_logger | |
| from utils.model_loader import ModelLoader | |
# Module-level logger for this file, obtained from the project's
# `setup_logger` helper and named after this module.
logger = setup_logger(__name__)
class IntentClassifier:
    """Classify a user query into one of a fixed set of intents.

    The sequence-classification model and its tokenizer are loaded once at
    construction time via ``ModelLoader.load_model_with_retry``. Each call to
    ``classify`` then runs one forward pass.

    NOTE(review): the default checkpoint is the public SST-2 *sentiment*
    model (labels NEGATIVE/POSITIVE). Mapping those two outputs onto
    ``database_query`` / ``product_description`` is arbitrary. Predictions
    are not meaningful intents until a checkpoint fine-tuned on intent data
    is supplied via ``model_name``. TODO confirm / replace.
    """

    DEFAULT_MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
    # Output class index -> intent label returned by classify().
    DEFAULT_INTENTS = {0: "database_query", 1: "product_description"}

    def __init__(self, model_name=DEFAULT_MODEL_NAME, intents=None):
        """Load the classification model and its tokenizer.

        Args:
            model_name: Model id/path handed to ``ModelLoader``; defaults to
                ``DEFAULT_MODEL_NAME``.
            intents: Optional mapping from output class index to intent label;
                defaults to ``DEFAULT_INTENTS``. Its size is passed to the
                model as ``num_labels``.

        Raises:
            ValueError: if ``intents`` is empty.
            Exception: anything raised while loading the model or tokenizer;
                it is logged (with traceback) and re-raised.
        """
        self.model_name = model_name
        # Copy so that later mutation of the caller's dict (or of the class
        # default) cannot change this instance's label mapping.
        self.intents = dict(self.DEFAULT_INTENTS if intents is None else intents)
        if not self.intents:
            raise ValueError("intents must contain at least one label")
        try:
            self.model = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoModelForSequenceClassification,
                num_labels=len(self.intents),
            )
            self.tokenizer = ModelLoader.load_model_with_retry(
                self.model_name,
                AutoTokenizer,
            )
        except Exception as e:
            logger.exception("Failed to initialize IntentClassifier: %s", e)
            raise
        # Inference only: make sure dropout etc. are disabled. ModelLoader's
        # internals are not visible here, so do not rely on it doing this.
        self.model.eval()

    def classify(self, query):
        """Predict the intent of a single query.

        Args:
            query: The text to classify. It is assumed to be a single string:
                only the first row of the model output is used.
                TODO confirm no caller passes a batch.

        Returns:
            ``(intent_label, confidence)``, where ``confidence`` is the softmax
            probability of the predicted class as a Python float. On any
            failure the error is logged and ``("error", 0.0)`` is returned
            instead of raising.
        """
        try:
            inputs = self.tokenizer(
                query, return_tensors="pt", truncation=True, padding=True
            )
            # Pure inference: skip autograd bookkeeping to save memory/time.
            with torch.no_grad():
                logits = self.model(**inputs).logits
            # Row 0 is the single query. Take the argmax over the class axis
            # of that row, so the predicted index and the reported probability
            # always come from the same row.
            probabilities = torch.nn.functional.softmax(logits[0], dim=-1)
            predicted_class = int(torch.argmax(probabilities))
            return self.intents[predicted_class], probabilities[predicted_class].item()
        except Exception as e:
            # Deliberate best-effort API: callers receive a sentinel result.
            logger.exception("Classification error: %s", e)
            return "error", 0.0