Spaces:
Sleeping
Sleeping
| from typing import Dict, List, Any | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModel, AutoConfig | |
| from huggingface_hub import hf_hub_download | |
| import os | |
class EndpointHandler:
    """Callable inference handler for a multi-task ViSoBERT classifier.

    A single shared encoder feeds one classification head per task
    (``sentiment``, ``topic``, ``hate_speech`` and ``clickbait``). Calling an
    instance with a request dictionary returns a dictionary that maps each
    requested task name to its prediction.
    """

    MODEL_REPO = 'bie-nhd/visobert-multitask'
    BASE_MODEL = 'uitnlp/visobert'
    HEAD_DROPOUT = 0.1
    DEFAULT_MAX_LENGTH = 128
    MULTI_LABEL_THRESHOLD = 0.5

    def __init__(self, path: str):
        """Download the checkpoint and build the tokenizer, encoder and heads.

        Args:
            path: Accepted for interface compatibility; it is not used, the
                weights are always fetched from ``MODEL_REPO``.
        """
        import warnings  # local import: only needed for the start-up check

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Single source of truth for every task: label count, output type,
        # label names and tokenisation length.
        self.task_config = {
            'sentiment': {
                'num_labels': 4,
                'type': 'single_label',
                'label_map': {0: 'Neutral', 1: 'Positive', 2: 'Negative', 3: 'Toxic'}
            },
            'topic': {
                'num_labels': 10,
                'type': 'single_label',
                'label_map': dict(enumerate([
                    'Spam', 'News', 'Academic', 'Other', 'Service',
                    'Jobs', 'Personal', 'Social', 'Help', 'Events',
                ]))
            },
            'hate_speech': {
                'num_labels': 5,
                'type': 'multi_label',
                'label_list': ['individual', 'groups', 'religion/creed', 'race/ethnicity', 'politics']
            },
            'clickbait': {
                'num_labels': 2,
                'type': 'single_label',
                'label_map': {0: 'Non-Clickbait', 1: 'Clickbait'},
                'dual_input': True,
                'max_length': 256
            }
        }

        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_REPO)
        model_path = hf_hub_download(repo_id=self.MODEL_REPO, filename="model.pt")

        # NOTE(security): torch.load may unpickle arbitrary objects unless
        # weights_only=True is in effect; only load checkpoints from a
        # trusted repository.
        # The checkpoint is loaded onto the CPU so that the tensors kept in
        # ``self.model_state_dict`` do not hold a second copy of the encoder
        # on the accelerator; the modules are moved to ``self.device`` below.
        checkpoint = torch.load(model_path, map_location='cpu')

        self.model = AutoModel.from_pretrained(self.BASE_MODEL)
        self.model_state_dict = checkpoint['encoder']
        self.model.load_state_dict(self.model_state_dict)

        hidden_size = self.model.config.hidden_size
        self.task_heads = torch.nn.ModuleDict({
            task: TaskClassificationHead(hidden_size, cfg['num_labels'], self.HEAD_DROPOUT)
            for task, cfg in self.task_config.items()
        })
        # NOTE(review): the key name 'task_heads' is an assumption about how
        # the checkpoint was saved -- confirm against the training script.
        # Without trained weights the heads predict from random initialisation.
        if 'task_heads' in checkpoint:
            self.task_heads.load_state_dict(checkpoint['task_heads'])
        else:
            warnings.warn(
                "Checkpoint has no 'task_heads' entry; the classification "
                "heads are randomly initialised and their predictions are "
                "not meaningful.",
                RuntimeWarning,
            )

        # Restored from the checkpoint; not read by preprocess/predict.
        self.log_vars = torch.nn.ParameterDict({
            task: torch.nn.Parameter(torch.zeros(1)) for task in self.task_config
        })
        self.log_vars.load_state_dict(checkpoint['log_vars'])

        self.model.to(self.device)
        self.task_heads.to(self.device)
        self.log_vars.to(self.device)

        # Both the encoder and the heads must be in eval mode, otherwise the
        # Dropout layers inside the heads make predictions non-deterministic.
        self.model.eval()
        self.task_heads.eval()

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenise one request for the task named in ``inputs['task']``.

        Args:
            inputs: Request dictionary. ``'task'`` must be a key of
                ``self.task_config``; the text is read from ``'text'`` or,
                failing that, ``'inputs'``; ``'title'`` is optional and only
                used by tasks flagged ``dual_input``.

        Returns:
            The tokenizer output as a plain dict whose tensors have been
            moved to ``self.device``.

        Raises:
            ValueError: If the task is missing or unknown, or no text is given.
        """
        task = inputs.get('task', None)
        if task is None or task not in self.task_config:
            raise ValueError(f"Invalid task: {task}")

        text = inputs.get('text', inputs.get('inputs', None))
        if text is None:
            raise ValueError("'text' (or 'inputs') key is required in the input dictionary")

        config = self.task_config[task]
        title = inputs.get('title', None)
        # Dual-input tasks join title and text with the separator below; when
        # no title is supplied the text is encoded on its own rather than
        # being prefixed with the literal string "None".
        if config.get('dual_input', False) and title is not None:
            text = f"{title} </s></s> {text}"

        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=config.get('max_length', self.DEFAULT_MAX_LENGTH),
            return_tensors='pt',
        )
        # The encoder lives on self.device, so its inputs must as well.
        return {name: tensor.to(self.device) for name, tensor in encoding.items()}

    def predict(self, task: str, preprocessed: Dict[str, Any]) -> Dict[str, Any]:
        """Run the encoder and the head for ``task`` on tokenised input.

        Args:
            task: A key of ``self.task_config`` / ``self.task_heads``.
            preprocessed: Keyword arguments for the encoder, as produced by
                ``preprocess``.

        Returns:
            For ``multi_label`` tasks ``{'labels': [...], 'scores': [...]}``,
            where ``labels`` are those whose sigmoid score exceeds the
            threshold; otherwise ``{'label': str, 'confidence': float}`` for
            the arg-max class. Only the first item of the batch is reported.
        """
        config = self.task_config[task]

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            hidden = self.model(**preprocessed).last_hidden_state
            # The head classifies the hidden state at sequence position 0.
            logits = self.task_heads[task](hidden[:, 0, :])

        if config['type'] == 'multi_label':
            probs = torch.sigmoid(logits).cpu().numpy()[0]
            active = [
                label
                for label, score in zip(config['label_list'], probs)
                if score > self.MULTI_LABEL_THRESHOLD
            ]
            return {'labels': active, 'scores': probs.tolist()}

        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
        pred_idx = int(np.argmax(probs))
        return {'label': config['label_map'][pred_idx], 'confidence': float(probs[pred_idx])}

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request and return predictions keyed by task name.

        Args:
            data: Request dictionary. ``'task'`` is required and matched
                case-insensitively; the value ``'all'`` runs every configured
                task. The remaining keys are described in ``preprocess``.
                The dictionary is not modified.

        Returns:
            ``{task_name: prediction}`` with one entry per executed task.

        Raises:
            ValueError: If ``'task'`` is missing, not a string or unknown, or
                if the text is missing.
        """
        task = data.get('task', None)
        print(f"Task: {task}")
        if task is None:
            raise ValueError("'task' key is required in the input dictionary")
        if not isinstance(task, str):
            raise ValueError(f"Invalid task: {task}")

        task = task.lower()
        if task == "all":
            selected = list(self.task_config)
        elif task in self.task_config:
            selected = [task]
        else:
            raise ValueError(f"Invalid task: {task}")

        results = {}
        for name in selected:
            # Work on a copy carrying the normalised task name so that the
            # caller's dictionary is left untouched and preprocess sees the
            # same (lower-cased) task that was validated here.
            request = {**data, 'task': name}
            results[name] = self.predict(name, self.preprocess(request))
        return results
class TaskClassificationHead(torch.nn.Module):
    """Feed-forward classification head: Linear -> ReLU -> Dropout -> Linear.

    The layers are stored as ``self.projection``, a ``Sequential`` whose two
    linear layers sit at indices 0 and 3.
    """

    def __init__(self, hidden_size: int, num_labels: int, dropout: float):
        """Build the head.

        Args:
            hidden_size: Width of the incoming feature vectors.
            num_labels: Number of output logits.
            dropout: Drop probability applied after the activation.
        """
        super().__init__()
        nn = torch.nn

        # The intermediate layer is half as wide as the input, but never
        # narrower than the number of outputs.
        width = hidden_size // 2
        if width < num_labels:
            width = num_labels

        layers = [
            nn.Linear(hidden_size, width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(width, num_labels),
        ]
        self.projection = nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the logits produced by passing ``hidden_states`` through the head."""
        logits = self.projection(hidden_states)
        return logits