import os
from typing import Dict, Any

import torch
import yaml
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import ModernBertConfig

from src.models import ModernBertForSentiment


class SentimentInference:
    """Binary sentiment classifier backed by a locally fine-tuned ModernBERT checkpoint.

    The model structure is rebuilt from the base ``ModernBertConfig`` named in the
    YAML ``model`` section, patched with the head settings from that same section,
    and then filled with weights from a local ``.pt`` checkpoint loaded onto CPU.
    """

    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and initialize model and tokenizer.

        Args:
            config_path: Path to a YAML file. Only its ``model`` and ``inference``
                sections are read; missing, empty, or null sections fall back to
                the defaults used below.

        Raises:
            OSError: If the config file or the checkpoint file cannot be opened.
        """
        config = self._read_config(config_path)
        # ``or {}`` also covers a section that is present but empty (YAML null),
        # which ``dict.get(key, {})`` would return as None.
        model_cfg = config.get('model') or {}
        inference_cfg = config.get('inference') or {}

        # An explicit inference.model_path wins; otherwise use the training output dir.
        model_weights_path = inference_cfg.get(
            'model_path',
            os.path.join(model_cfg.get('output_dir', 'checkpoints'), 'best_model.pt'),
        )
        base_model_name = model_cfg.get('name', 'answerdotai/ModernBERT-base')
        # inference.max_length overrides model.max_length; 256 if neither is set.
        self.max_length = inference_cfg.get('max_length', model_cfg.get('max_length', 256))

        print(f"Loading tokenizer from: {base_model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        print(f"Loading ModernBertConfig from: {base_model_name}")
        bert_config = ModernBertConfig.from_pretrained(base_model_name)
        # Head settings come from the training config so the rebuilt model can
        # accept the checkpoint. NOTE(review): pooling_strategy,
        # num_weighted_layers and loss_function are consumed by
        # ModernBertForSentiment, which is not visible here. Confirm they must
        # match training.
        bert_config.classifier_dropout = model_cfg.get('dropout', bert_config.classifier_dropout)
        bert_config.pooling_strategy = model_cfg.get('pooling_strategy', 'cls')
        bert_config.num_weighted_layers = model_cfg.get('num_weighted_layers', 4)
        bert_config.loss_function = model_cfg.get(
            'loss_function', {'name': 'SentimentWeightedLoss', 'params': {}}
        )
        # A single output logit; predict() applies a sigmoid to it.
        bert_config.num_labels = 1

        print("Instantiating ModernBertForSentiment model structure...")
        self.model = ModernBertForSentiment(bert_config)

        print(f"Loading model weights from local checkpoint: {model_weights_path}")
        self.model.load_state_dict(self._read_state_dict(model_weights_path))
        self.model.eval()
        print("Model loaded successfully.")

    @staticmethod
    def _read_config(config_path: str) -> Dict[str, Any]:
        """Parse the YAML file at *config_path*; an empty document yields ``{}``."""
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty file.
            return yaml.safe_load(f) or {}

    @staticmethod
    def _read_state_dict(weights_path: str) -> Dict[str, Any]:
        """Load a checkpoint onto CPU and return the model state dict it holds.

        Accepts either a training checkpoint with a ``model_state_dict`` entry or a
        bare state dict saved directly.
        """
        # NOTE(security): torch.load unpickles the file. Only load trusted
        # checkpoints. torch < 2.6 does not default to weights_only=True.
        checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
        if 'model_state_dict' in checkpoint:
            return checkpoint['model_state_dict']
        return checkpoint

    def predict(self, text: str) -> Dict[str, Any]:
        """Classify a single text as positive or negative.

        Args:
            text: Input text. It is tokenized with truncation to ``self.max_length``.

        Returns:
            ``{"sentiment": "positive" | "negative", "confidence": float}``.
            ``confidence`` is the sigmoid of the model's single logit, i.e. the
            positive-class probability. For a "negative" result it is <= 0.5;
            it is not the probability of the predicted label.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=self.max_length
        )
        with torch.no_grad():
            outputs = self.model(
                input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask']
            )
        logits = outputs["logits"]
        # .item() requires a single-element tensor (one text, one logit).
        prob = torch.sigmoid(logits).item()
        return {"sentiment": "positive" if prob > 0.5 else "negative", "confidence": prob}