| import torch |
| import transformers |
| import numpy as np |
|
|
|
|
# safetensors is optional: HAS_SAFETENSORS records whether st_load_file
# could be imported, so callers can raise a helpful error instead.
try:
    from safetensors.torch import load_file as st_load_file
except ImportError:
    HAS_SAFETENSORS = False
else:
    HAS_SAFETENSORS = True

# Names of the five per-word fixation features (presumably in the model's
# output-column order -- confirm against the training code).
FEATURE_NAMES = ['nFix', 'FFD', 'GPT', 'TRT', 'fixProp']

# Sliding-window inference: tokens per model call, and how many tokens two
# consecutive windows share.
WINDOW_SIZE = 512
OVERLAP = 50

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
class RobertaRegressionModel(torch.nn.Module):
    """RoBERTa encoder followed by a linear head that predicts per-token values.

    Args:
        model_name: model id or local path passed to
            ``transformers.RobertaModel.from_pretrained``.
        n_outputs: number of values predicted per token (default 5).
    """

    def __init__(self, model_name='roberta-base', n_outputs=5):
        super().__init__()
        self.roberta = transformers.RobertaModel.from_pretrained(model_name)
        # Take the encoder width from the loaded config rather than guessing it
        # from the name: a local path may or may not contain 'large' regardless
        # of the actual model size.
        embed_size = self.roberta.config.hidden_size
        self.decoder = torch.nn.Linear(embed_size, n_outputs)

    def forward(self, input_ids, attention_mask, predict_mask):
        """Return per-token predictions, with -1.0 wherever ``predict_mask == 0``.

        Args:
            input_ids: token ids for the encoder.
            attention_mask: encoder attention mask.
            predict_mask: mask over token positions; assumed to have the
                encoder output's leading shape, e.g. (batch, seq) -- TODO confirm.

        Returns:
            Tensor shaped like the encoder's ``last_hidden_state`` but with the
            last dimension replaced by ``n_outputs``.
        """
        hidden = self.roberta(input_ids, attention_mask=attention_mask).last_hidden_state
        y_pred = self.decoder(hidden)
        # masked_fill broadcasts the trailing singleton axis over all outputs.
        skip = (predict_mask == 0).unsqueeze(-1).to(y_pred.device)
        return y_pred.masked_fill(skip, -1.0)
|
|
class FixationsPredictor2:
    """Predict eye-tracking (fixation) features for text with a RoBERTa regressor.

    The underlying ``RobertaRegressionModel`` emits five values per token.
    This wrapper runs it over arbitrarily long inputs with an overlapping
    sliding window, reduces token predictions to one row per whitespace-
    separated word (first sub-token of the word, clipped at 0), and can place
    those word features onto the token sequence of another tokenizer.
    """

    def __init__(self, checkpoint_path, model_name='roberta-base'):
        """Build tokenizer and model, load the weights, and switch to eval mode.

        Args:
            checkpoint_path: ``.safetensors``/``.pt``/``.pth``/``.bin`` state
                dict, or a path without extension to be resolved (see
                ``_load_checkpoint``).
            model_name: model id used for both the tokenizer and the model.
        """
        self.model_name = model_name
        # add_prefix_space=True gives every pre-split word a leading 'G-dot'
        # token marker, which _get_word_boundaries uses to find word starts.
        self.tokenizer = transformers.RobertaTokenizer.from_pretrained(
            model_name, add_prefix_space=True
        )
        self.model = RobertaRegressionModel(model_name).to(device)
        self._load_checkpoint(checkpoint_path)
        self.model.eval()

    def _load_checkpoint(self, path):
        """Load a state dict from ``path`` into ``self.model``.

        ``.safetensors`` files are read with safetensors; ``.pt``, ``.pth`` and
        ``.bin`` files with ``torch.load``. For any other path the method tries
        ``path + '.safetensors'`` and then ``path + '.pt'``.

        Raises:
            ImportError: a ``.safetensors`` file was requested but the
                ``safetensors`` package is not installed.
            FileNotFoundError: no candidate file exists for an extension-less path.
        """
        import os

        # Accept pathlib.Path objects as well as strings.
        path = os.fspath(path)
        if path.endswith('.safetensors'):
            if not HAS_SAFETENSORS:
                raise ImportError('pip install safetensors')
            state_dict = st_load_file(path, device=str(device))
        elif path.endswith(('.pt', '.pth', '.bin')):
            # NOTE(review): torch.load unpickles the file -- only load trusted
            # checkpoints (weights_only=True where the torch version allows it).
            state_dict = torch.load(path, map_location=device)
        else:
            for ext in ('.safetensors', '.pt'):
                if os.path.exists(path + ext):
                    self._load_checkpoint(path + ext)
                    return
            raise FileNotFoundError(f'체크포인트 없음: {path}')
        self.model.load_state_dict(state_dict)

    def _predict_with_sliding_window(self, input_ids_full, attention_mask_full,
                                     window_size=None, overlap=None):
        """Run the model over one (possibly long) token sequence.

        Sequences of at most ``window_size`` tokens are processed in one call.
        Longer ones are split into windows of ``window_size`` tokens where
        consecutive windows share ``overlap`` tokens. Inside a shared region
        the earlier window's weight ramps linearly from 1 to 0 and the later
        window's from 0 to 1; the weighted sum is then normalised.

        NOTE(review): inner windows are plain slices of the full sequence, so
        they do not start with <s> / end with </s>; confirm this matches how
        the model was trained.

        Args:
            input_ids_full: token ids; the code squeezes dim 0, so a batch of
                one, shape (1, seq_len), is assumed -- TODO confirm.
            attention_mask_full: attention mask with the same shape.
            window_size: tokens per model call (default: ``WINDOW_SIZE``).
            overlap: tokens shared by consecutive windows (default: ``OVERLAP``).

        Returns:
            numpy array with one row of 5 predictions per token.

        Raises:
            ValueError: unless ``0 <= overlap < window_size`` (the window would
                otherwise never advance).
        """
        if window_size is None:
            window_size = WINDOW_SIZE
        if overlap is None:
            overlap = OVERLAP
        if overlap < 0 or overlap >= window_size:
            raise ValueError(
                f'overlap must satisfy 0 <= overlap < window_size '
                f'(got overlap={overlap}, window_size={window_size})'
            )

        seq_len = input_ids_full.shape[1]
        if seq_len <= window_size:
            with torch.no_grad():
                pred = self.model(input_ids_full, attention_mask_full,
                                  attention_mask_full.clone())
            return pred.squeeze(0).cpu().numpy()

        stride = window_size - overlap
        predictions = np.zeros((seq_len, 5), dtype=np.float32)
        weights = np.zeros(seq_len, dtype=np.float32)

        start = 0
        while start < seq_len:
            end = min(start + window_size, seq_len)
            mask_win = attention_mask_full[:, start:end]
            with torch.no_grad():
                pred_win = self.model(input_ids_full[:, start:end], mask_win,
                                      mask_win.clone())
            pred_np = pred_win.squeeze(0).cpu().numpy()

            win_weight = np.ones(end - start, dtype=np.float32)
            ramp_len = min(overlap, end - start)
            # Guard ramp_len == 0: ``arr[-0:]`` is the whole array, not empty.
            if ramp_len > 0:
                if start > 0:
                    win_weight[:ramp_len] = np.linspace(0, 1, ramp_len)
                if end < seq_len:
                    win_weight[-ramp_len:] = np.linspace(1, 0, ramp_len)

            predictions[start:end] += pred_np * win_weight[:, None]
            weights[start:end] += win_weight

            if end == seq_len:
                break
            start += stride

        covered = weights > 0
        predictions[covered] /= weights[covered, None]
        return predictions

    def _get_word_boundaries(self, input_ids):
        """Group token positions into words.

        A new word starts at every token beginning with 'Ġ'; following tokens
        without it are appended to the current word. '<s>', '</s>' and '<pad>'
        close the current word and are not part of any word.

        Args:
            input_ids: list of token ids for a single sequence.

        Returns:
            list of lists of token positions, one inner list per word.
        """
        tokens = self.tokenizer.convert_ids_to_tokens(list(input_ids))
        boundary_tokens = {'<s>', '</s>', '<pad>'}
        words = []
        current = []
        for pos, tok in enumerate(tokens):
            if tok in boundary_tokens:
                if current:
                    words.append(current)
                    current = []
                continue
            if tok.startswith('Ġ') and current:
                words.append(current)
                current = []
            current.append(pos)
        if current:
            words.append(current)
        return words

    def predict_raw_text(self, text):
        """Predict the five features for each whitespace-separated word of ``text``.

        Returns:
            ``(word_features, words_list)``: a float32 array with one row per
            detected word (the prediction of the word's first token, negative
            values clipped to 0) and ``text.strip().split()``.
        """
        words_list = text.strip().split()
        encoding = self.tokenizer(
            [words_list],
            is_split_into_words=True,
            return_tensors='pt',
            truncation=False,
            padding=False,
        )
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        token_preds = self._predict_with_sliding_window(input_ids, attention_mask)
        word_boundaries = self._get_word_boundaries(input_ids.squeeze(0).cpu().tolist())

        word_features = np.zeros((len(word_boundaries), 5), dtype=np.float32)
        for w_idx, token_indices in enumerate(word_boundaries):
            # Each word takes the prediction of its first sub-token.
            word_features[w_idx] = np.clip(token_preds[token_indices[0]], 0, None)
        return word_features, words_list

    def predict_and_remap_to_tokenizer(self, input_ids_rm, attention_mask_rm, rm_tokenizer):
        """Compute fixation features for every row of a batch from another tokenizer.

        Each row is decoded with ``rm_tokenizer``, run through
        ``predict_raw_text`` and mapped back onto the row's own tokens.

        Args:
            input_ids_rm: (batch, seq_len) token ids produced by ``rm_tokenizer``.
            attention_mask_rm: matching attention mask; ``None`` means all ones.
            rm_tokenizer: tokenizer that produced ``input_ids_rm``.

        Returns:
            ``(fixations, fixations_attention_mask)`` with shapes
            (batch, seq_len, 5) float32 and (batch, seq_len) long.
        """
        if attention_mask_rm is None:
            attention_mask_rm = torch.ones_like(input_ids_rm)
        batch_size, seq_len_rm = input_ids_rm.shape[0], input_ids_rm.shape[1]
        if batch_size == 0:
            # torch.stack would reject an empty list.
            return (torch.zeros(0, seq_len_rm, 5, dtype=torch.float32),
                    torch.zeros(0, seq_len_rm, dtype=torch.long))

        pad_id = rm_tokenizer.pad_token_id
        fixations_batch = []
        masks_batch = []
        for b in range(batch_size):
            ids = input_ids_rm[b].cpu().tolist()
            mask = attention_mask_rm[b].cpu().tolist()
            fixations_batch.append(self._fixations_for_row(ids, mask, rm_tokenizer, pad_id))
            masks_batch.append(torch.tensor(mask, dtype=torch.long))
        return torch.stack(fixations_batch), torch.stack(masks_batch)

    def _compute_mapped_fixations(self, input_ids_rm, attention_mask_rm=None):
        """Compute fixation features for row 0, decoding with ``self.tokenizer``.

        Only the first row of the batch is processed -- NOTE(review): confirm
        callers always pass a batch of one.

        Returns:
            6-tuple ``(fixations, fix_attn, None, None, None, None)`` where
            ``fixations`` is (1, seq_len, 5) and ``fix_attn`` is (1, seq_len).
        """
        if attention_mask_rm is None:
            attention_mask_rm = torch.ones_like(input_ids_rm)
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            # Only fall back when no pad id exists; 0 is a valid id.
            pad_id = 1
        ids = input_ids_rm[0].cpu().tolist()
        mask = attention_mask_rm[0].cpu().tolist()
        remapped = self._fixations_for_row(ids, mask, self.tokenizer, pad_id)
        fix_attn = torch.tensor(mask, dtype=torch.long).unsqueeze(0)
        return remapped.unsqueeze(0), fix_attn, None, None, None, None

    def _fixations_for_row(self, ids, mask, rm_tokenizer, pad_id):
        """Predict word features for one row and place them on that row's tokens.

        Args:
            ids: token ids (of ``rm_tokenizer``) for a single row.
            mask: attention-mask values aligned with ``ids``.
            rm_tokenizer: tokenizer that produced ``ids``.
            pad_id: id dropped before decoding (``None`` drops nothing extra).

        Returns:
            float32 tensor of shape (len(ids), 5).
        """
        ids_no_pad = [i for i, m in zip(ids, mask) if m == 1 and i != pad_id]
        text = rm_tokenizer.decode(ids_no_pad, skip_special_tokens=True)
        word_features, _ = self.predict_raw_text(text)
        return self._remap_features_to_rm_tokens(word_features, text, ids, mask, rm_tokenizer)

    def _remap_features_to_rm_tokens(self, word_features, text, rm_input_ids, rm_mask, rm_tokenizer):
        """Write each word's features onto the first ``rm_tokenizer`` token of that word.

        Words are ``text.strip().split()``, aligned to tokens by
        ``_align_words_to_rm_tokens``. All other positions, and positions whose
        mask value is not 1, stay zero.

        Returns:
            float32 tensor of shape (len(rm_input_ids), 5).
        """
        words = text.strip().split()
        seq_len = len(rm_input_ids)
        output = torch.zeros(seq_len, 5, dtype=torch.float32)

        rm_tokens = rm_tokenizer.convert_ids_to_tokens(rm_input_ids)
        word_to_rm_indices = _align_words_to_rm_tokens(words, rm_tokens, rm_tokenizer)

        n_words = min(len(words), len(word_features), len(word_to_rm_indices))
        for w_idx in range(n_words):
            indices = word_to_rm_indices[w_idx]
            if not indices:
                continue
            first = indices[0]
            if first < seq_len and rm_mask[first] == 1:
                output[first] = torch.tensor(word_features[w_idx], dtype=torch.float32)
        return output
|
|
|
|
| def _align_words_to_rm_tokens(words, rm_tokens, rm_tokenizer): |
| special_ids = set(rm_tokenizer.all_special_ids) |
| word_to_indices = [] |
| tok_idx = 0 |
|
|
| for word in words: |
| indices = [] |
| chars_remaining = len(word) |
|
|
| while tok_idx < len(rm_tokens) and chars_remaining > 0: |
| tok = rm_tokens[tok_idx] |
| tok_id = rm_tokenizer.convert_tokens_to_ids(tok) |
| if tok_id in special_ids: |
| tok_idx += 1 |
| continue |
|
|
| tok_clean = tok.lstrip('Ġ▁ ') |
| indices.append(tok_idx) |
| chars_remaining -= len(tok_clean) |
| tok_idx += 1 |
|
|
| word_to_indices.append(indices) |
|
|
| return word_to_indices |
|
|