| import torch |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig |
| |
| from typing import Dict, Any |
| import yaml |
|
|
class SentimentInference:
    """Binary sentiment classifier backed by a ModernBERT checkpoint on the Hugging Face Hub."""

    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and initialize model and tokenizer from Hugging Face Hub.

        Args:
            config_path: Path to a YAML file with ``model`` and (optional)
                ``inference`` sections.

        Raises:
            ValueError: If ``model.name_or_path`` is missing from the config.
        """
        with open(config_path, 'r') as f:
            config_data = yaml.safe_load(f)

        model_yaml_cfg = config_data.get('model', {})
        inference_yaml_cfg = config_data.get('inference', {})

        model_hf_repo_id = model_yaml_cfg.get('name_or_path')
        if not model_hf_repo_id:
            raise ValueError("model.name_or_path must be specified in config.yaml (e.g., 'username/model_name')")

        # The tokenizer may live in a different repo; default to the model repo.
        tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)

        # inference.max_length wins over model.max_length; final fallback is 512.
        self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))

        print(f"Loading tokenizer from: {tokenizer_hf_repo_id}")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_hf_repo_id)

        print(f"Loading base ModernBertConfig from: {model_hf_repo_id}")
        loaded_config = ModernBertConfig.from_pretrained(model_hf_repo_id)

        # Overlay YAML-provided architecture options onto the hub config.
        loaded_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
        loaded_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 4)
        # BUGFIX: only override classifier_dropout when the YAML actually sets
        # it. The previous code assigned model_yaml_cfg.get('dropout')
        # unconditionally, clobbering the checkpoint's value with None whenever
        # the key was absent (nn.Dropout(None) then fails at model build time).
        if model_yaml_cfg.get('dropout') is not None:
            loaded_config.classifier_dropout = model_yaml_cfg['dropout']
        # Single-logit head by default: predict() applies a sigmoid to one logit.
        loaded_config.num_labels = model_yaml_cfg.get('num_labels', 1)

        print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_hf_repo_id,
            config=loaded_config,
            trust_remote_code=True
        )
        self.model.eval()  # inference only: disables dropout / other train-mode layers
        print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")

    def predict(self, text: str) -> Dict[str, Any]:
        """Classify *text* and return its sentiment.

        Args:
            text: Raw input string; tokenized with truncation at ``self.max_length``.

        Returns:
            Dict with ``sentiment`` ('positive' if the sigmoid probability
            exceeds 0.5, else 'negative') and ``confidence`` (that probability).
            Assumes a single-logit head (num_labels == 1) — ``.item()`` would
            raise on a multi-logit output.

        Raises:
            ValueError: If the model output carries no 'logits' entry.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        logits = outputs.get("logits")
        if logits is None:
            raise ValueError("Model output did not contain 'logits'. Check model's forward pass.")
        prob = torch.sigmoid(logits).item()
        return {"sentiment": "positive" if prob > 0.5 else "negative", "confidence": prob}