import torch
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification

# Define label mappings (must match the mappings used during training)
label2id = {'<negative_object>': 0, '<positive_object>': 1, 'other': 2}
id2label = {v: k for k, v in label2id.items()}
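# The inverse mapping resolves to: {0: '<negative_object>', 1: '<positive_object>', 2: 'other'}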

def prepare_input(tokens, tokenizer, max_length=128):
    # Encode pre-split words; the fast tokenizer tracks word ids for alignment
    encoding = tokenizer(
        tokens,
        is_split_into_words=True,
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True  # returned for debugging; not used below
    )
    return encoding
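# e.g. prepare_input(['hello', 'world'], tokenizer) returns a BatchEncoding whose
# 'input_ids' and 'attention_mask' tensors both have shape (1, 128).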

def split_sentence(sentence):
    # List of special tokens to preserve as single words
    special_tokens = ['<positive_object>', '<negative_object>']
    # Punctuation marks and symbols to split off as separate tokens
    punctuation = ',.?!;:()[]{}""\'`@#$%^&*+=|\\/<>-‘’–—'
    # Initialize result list and temporary word
    result = []
    current_word = ''
    i = 0
    while i < len(sentence):
        # Check for special tokens first, so their '<' and '>' are not split off
        found_special = False
        for token in special_tokens:
            if sentence[i:].startswith(token):
                # Flush the word accumulated so far
                if current_word:
                    result.append(current_word)
                    current_word = ''
                # Add the special token as a single word
                result.append(token)
                i += len(token)
                found_special = True
                break
        if found_special:
            continue
        # Handle punctuation
        if sentence[i] in punctuation:
            # Flush the word accumulated so far
            if current_word:
                result.append(current_word)
                current_word = ''
            # Add punctuation as a separate token
            result.append(sentence[i])
        # Handle spaces
        elif sentence[i].isspace():
            if current_word:
                result.append(current_word)
                current_word = ''
        # Build regular words
        else:
            current_word += sentence[i]
        i += 1
    # Add the final word if one remains
    if current_word:
        result.append(current_word)
    return result
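# Example (hypothetical input):
#   split_sentence("grab the <positive_object>, please!")
#   -> ['grab', 'the', '<positive_object>', ',', 'please', '!']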

def predict(sentence, model, tokenizer, device, max_length=128):
    # Normalize whitespace, lowercase, then split into word-level tokens
    tokens = split_sentence(' '.join(sentence.lower().split()))
    # Prepare the input
    encoding = prepare_input(tokens, tokenizer, max_length=max_length)
    word_ids = encoding.word_ids(batch_index=0)  # word index for each subword
    # Move tensors to the target device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    # Inference
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1).cpu().numpy()[0]
    # Decode subword tokens and map predicted ids to label strings
    tokens_decoded = tokenizer.convert_ids_to_tokens(input_ids.cpu().numpy()[0])
    labels = [id2label.get(pred, 'O') for pred in predictions]
    # Align subword predictions with the original word-level tokens,
    # keeping the label of each word's first subword
    aligned_predictions = []
    previous_word_idx = None
    for token, label, word_idx in zip(tokens_decoded, labels, word_ids):
        if word_idx is None:
            continue  # skip special tokens ([CLS], [SEP]) and padding
        if word_idx != previous_word_idx:
            aligned_predictions.append((tokens[word_idx], label))
        previous_word_idx = word_idx
    return aligned_predictions

def load_token_classifier(pretrained_token_classifier_path, device):
    # Load tokenizer and model from a fine-tuned checkpoint
    tokenizer = DistilBertTokenizerFast.from_pretrained(pretrained_token_classifier_path)
    token_classifier = DistilBertForTokenClassification.from_pretrained(pretrained_token_classifier_path)
    token_classifier.to(device)
    token_classifier.eval()  # disable dropout for deterministic inference
    return token_classifier, tokenizer
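
# A minimal usage sketch. The checkpoint path below is a placeholder
# (assumption: a fine-tuned checkpoint saved locally); substitute your own
# model directory or Hub id.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = load_token_classifier("./token-classifier-checkpoint", device)
    sentence = "grab the <positive_object> but avoid the <negative_object>!"
    for word, label in predict(sentence, model, tokenizer, device):
        print(f"{word}\t{label}")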