|
|
""" |
|
|
Adapts a general-purpose BERT-type encoder to the extractive QA task.
BertBasedQAModel follows the same structure as the other models (it subclasses
QAModel for consistency) and holds a custom QAModule, which wires the encoder's
token representations to the linear layer that produces the start/end logits
needed for the QA task.
|
|
|
|
|
Benefits: |
|
|
- Facilitates a **plug-and-play** selection of the underlying encoder model. |
|
|
- Follows a clean composition pattern, avoiding multiple inheritance from both
QAModel and torch.nn.Module, which would introduce unnecessary complexity
(e.g., which __init__() or which train() gets called).
|
|
""" |
|
|
|
|
|
import torch |
|
|
import random |
|
|
import json |
|
|
import numpy as np |
|
|
from dataclasses import asdict |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional, List, Tuple |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
from transformers.tokenization_utils_base import BatchEncoding |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
from src.models.base_qa_model import QAModel |
|
|
from src.config.model_configs import BertQAConfig |
|
|
from src.etl.types import QAExample, Prediction |
|
|
from src.evaluation.evaluator import Evaluator, Metrics |
|
|
from src.utils.constants import DEBUG_SEED |
|
|
|
|
|
|
|
|
def set_seed(seed: int = DEBUG_SEED) -> None:
    """
    Seed Python's `random`, NumPy and PyTorch so experiments are reproducible.

    When CUDA is available, this also seeds every CUDA device. It puts cuDNN into
    deterministic mode and turns off its autotuner (`benchmark = False`). That
    can cost speed, so production runs may want to skip this function. When the
    Apple MPS backend is available, its generator is seeded as well.

    Relevant resources:
    - https://stackoverflow.com/questions/67581281/does-torch-manual-seed-include-the-operation-of-torch-cuda-manual-seed-all
    - https://docs.pytorch.org/docs/stable/notes/randomness.html

    # TODO - move to utilities file
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
|
|
|
|
|
|
|
|
class QADataset(Dataset):
    """
    Map-style Dataset exposing the values of a Dict[str, QAExample] by position.

    Handing this to a DataLoader lets it take care of shuffling and batch
    boundaries, so the training loop needs no manual index arithmetic.

    # TODO - move to utilities file
    """

    def __init__(self, examples_dict: Dict[str, QAExample]):
        """Snapshot the dict's examples (in insertion order) into a positional list."""
        self.examples = [*examples_dict.values()]

    def __len__(self) -> int:
        """Number of stored examples; the DataLoader derives its batch count from it."""
        return len(self.examples)

    def __getitem__(self, idx: int) -> QAExample:
        """Example stored at position `idx`."""
        return self.examples[idx]
|
|
|
|
|
|
|
|
class BertBasedQAModel(QAModel):
    """
    Extractive QA model built on a general-purpose (BERT-type) encoder.

    Holds a Hugging Face tokenizer and a QAModule (encoder + linear start/end head)
    by composition, and implements the QAModel interface (train / predict).
    Position 0 of every encoded sequence ([CLS]) doubles as the "no answer" slot.
    """

    def __init__(self, config: BertQAConfig) -> None:
        super().__init__()

        set_seed()
        assert isinstance(config, BertQAConfig), "Incompatible configuration object."
        self.config = config

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.backbone_name, use_fast=True
        )
        self.qa_module = QAModule(config=self.config)

        # Training labels unanswerable questions with position 0 and prediction
        # scores "no answer" at position 0, so the tokenizer must put [CLS] first.
        test_encoding = self.tokenizer("testQ", "testC", return_tensors="pt")
        assert (
            test_encoding["input_ids"][0, 0].item() == self.tokenizer.cls_token_id
        ), "Model doesn't follow BERT's [CLS]-at-position-0 convention."

    @classmethod
    def load_from_experiment(
        cls, experiment_dir: Path, config_class, device: str = "mps"
    ):
        """
        Loads a trained model from the experiment tracking directory.

        Files read:
            <experiment_dir>/config.json                 -> config_class(**json)
            <experiment_dir>/model/tokenizer/            -> tokenizer
            <experiment_dir>/model/pytorch_model.bin     -> QAModule state dict

        experiment_dir: path to the experiment (e.g., 'experiments/<date_time>_bert-base_ALL_articles')
        config_class: configuration class used to rebuild the config from config.json
        device: by default we load into Apple MPS for local experimentation with
            predictions (e.g., threshold tuning)

        Raises FileNotFoundError if the model directory, tokenizer or weights are
        missing. These checks run before the backbone is instantiated, so a broken
        experiment directory fails fast.
        """
        experiment_dir = Path(experiment_dir)
        model_dir = experiment_dir / "model"
        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory not found: {model_dir}")
        tokenizer_path = model_dir / "tokenizer"
        if not tokenizer_path.exists():
            raise FileNotFoundError(f"Tokenizer not found: {tokenizer_path}")
        weights_path = model_dir / "pytorch_model.bin"
        if not weights_path.exists():
            raise FileNotFoundError(f"Model weights not found: {weights_path}")

        print(f"\nLoading model from experiment: {experiment_dir.name}")
        with open(experiment_dir / "config.json", "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        # The device recorded at training time is overridden by the requested one.
        config_dict["device"] = device
        config = config_class(**config_dict)

        model = cls(config)
        model.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)

        # A state dict only contains tensors; weights_only=True refuses to unpickle
        # arbitrary objects from disk.
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        model.qa_module.load_state_dict(state_dict)

        model.qa_module.eval()
        print("Model loaded successfully and set to eval mode.")
        return model

    def train(
        self,
        train_examples: Optional[Dict[str, QAExample]] = None,
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> None:
        """
        Trains the QA model on provided training examples.

        Each epoch shuffles the training set and optimizes, with AdamW, the mean of
        the start and end cross-entropy losses. It then reports EM/F1 on the training
        set (and on the validation set when given). The module is left in eval mode.

        train_examples: question_id -> QAExample; must be non-empty.
        val_examples: optional examples evaluated after every epoch.
        """
        set_seed()
        self.qa_module.train()

        assert train_examples is not None, "Training examples cannot be None."
        assert len(train_examples) > 0, "Training examples cannot be empty."

        self._print_training_setup(train_examples, val_examples, self.config)

        optimizer = torch.optim.AdamW(
            self.qa_module.parameters(),
            lr=self.config.learning_rate,
        )
        # Examples whose answer cannot be aligned to context tokens are labelled -1
        # by _extract_gold_positions and excluded from the loss.
        loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)

        dataloader = DataLoader(
            QADataset(train_examples),
            batch_size=self.config.batch_size,
            shuffle=True,
            # Keep raw QAExample objects; tokenization happens per batch.
            collate_fn=lambda batch: batch,
        )
        num_batches = len(dataloader)
        print(f"Total batches per epoch: {num_batches}")
        print(f"{'='*70}\n")

        for epoch in range(self.config.num_epochs):
            print(f"{'='*70}")
            print(f"EPOCH {epoch + 1}/{self.config.num_epochs}")
            print(f"{'='*70}")
            total_loss = 0.0
            num_optimized_batches = 0
            set_truncated_examples = set()

            for batch_idx, batch_examples in enumerate(dataloader):
                batch_loss = self._train_step(
                    batch_examples, optimizer, loss_fn, set_truncated_examples
                )
                if batch_loss is not None:
                    total_loss += batch_loss
                    num_optimized_batches += 1

                if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == num_batches:
                    # Averaged over batches that actually produced a loss.
                    avg_loss = (
                        total_loss / num_optimized_batches
                        if num_optimized_batches
                        else float("nan")
                    )
                    print(
                        f" Batch {batch_idx + 1}/{num_batches} | Avg Loss: {avg_loss:.4f}"
                    )

            avg_epoch_loss = (
                total_loss / num_optimized_batches
                if num_optimized_batches
                else float("nan")
            )

            _, _ = self._print_epoch_summary(
                epoch=epoch + 1,
                total_epochs=self.config.num_epochs,
                avg_loss=avg_epoch_loss,
                num_truncated=len(set_truncated_examples),
                train_examples=train_examples,
                val_examples=val_examples,
            )

        print("Training Completed.")
        self.qa_module.eval()

    def _train_step(
        self,
        batch_examples: List[QAExample],
        optimizer: torch.optim.Optimizer,
        loss_fn: torch.nn.Module,
        set_truncated_examples: set[str],
    ) -> Optional[float]:
        """
        Runs one forward/backward/optimizer step on a batch of examples.

        Returns the scalar batch loss, or None when the batch is skipped because
        none of its examples could be aligned to token positions (all labels -1).
        In that case mean-reduced cross-entropy would be 0/0 = NaN, and stepping on
        it would write NaNs into the weights.
        """
        batch_dict = {ex.question_id: ex for ex in batch_examples}
        qids, _, _, encoded = self._prepare_batch(batch_dict)
        assert (
            len(qids) == encoded["input_ids"].shape[0] == len(batch_examples)
        ), "Training shape mismatch after batch prepare."

        gold_starts, gold_ends = self._extract_gold_positions(
            batch_examples, encoded, set_truncated_examples
        )
        # Start and end labels are set to -1 together, so checking starts suffices.
        if bool((gold_starts == -1).all()):
            return None

        device = next(self.qa_module.parameters()).device
        gold_starts = gold_starts.to(device)
        gold_ends = gold_ends.to(device)

        start_logits, end_logits = self.qa_module(
            input_ids=encoded["input_ids"],
            attention_mask=encoded.get("attention_mask"),
            token_type_ids=encoded.get("token_type_ids"),
        )

        expected_shape = (len(batch_examples), encoded["input_ids"].shape[1])
        assert (
            start_logits.shape == expected_shape
        ), f"start_logits shape {start_logits.shape} != expected {expected_shape}"
        assert (
            end_logits.shape == expected_shape
        ), f"end_logits shape {end_logits.shape} != expected {expected_shape}"

        start_loss = loss_fn(start_logits, gold_starts)
        end_loss = loss_fn(end_logits, gold_ends)
        loss = (start_loss + end_loss) / 2.0
        assert loss.dim() == 0, f"Loss should be scalar, got shape {loss.shape}"

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    def _print_epoch_summary(
        self,
        epoch: int,
        total_epochs: int,
        avg_loss: float,
        num_truncated: int,
        train_examples: Dict[str, QAExample],
        val_examples: Optional[Dict[str, QAExample]] = None,
    ) -> Tuple[Metrics, Optional[Metrics]]:
        """
        Prints the epoch loss and evaluates on the train (and optional validation) set.

        Evaluation switches the module to eval mode via predict(); it is switched
        back to train mode before returning (train_metrics, val_metrics).
        """
        if num_truncated > 0:
            print(
                f"{num_truncated} examples truncated throughout the epoch."
                f" Start & end answer tokens could not be identified."
            )
        print(f"\nEpoch {epoch}/{total_epochs} Complete | Average Loss: {avg_loss:.4f}")
        train_metrics = self._evaluate_and_print(train_examples, "Training")
        val_metrics = None
        if val_examples is not None:
            val_metrics = self._evaluate_and_print(val_examples, "Validation")

        self.qa_module.train()
        print(f"{'='*70}\n")
        return train_metrics, val_metrics

    def _evaluate_and_print(
        self, examples: Dict[str, QAExample], split_name: str
    ) -> Metrics:
        """Runs predict() on `examples`, scores them with Evaluator and prints EM/F1."""
        print(f"Evaluating on {split_name} set...")
        predictions = self.predict(examples)
        metrics = Evaluator().evaluate(predictions, examples)
        print(
            f"{split_name} | EM: {metrics.exact_score:.2f}%, F1: {metrics.f1_score:.2f}%"
        )
        return metrics

    @staticmethod
    def _context_token_mask(encoded: BatchEncoding) -> torch.Tensor:
        """
        Boolean tensor (batch, seq_len) that is True exactly on context tokens.

        Uses the fast tokenizer's sequence ids: questions are passed as `text` and
        contexts as `text_pair` in _encode_pairs, so context tokens carry id 1, while
        special and padding tokens carry None. Unlike token_type_ids, this excludes
        the trailing [SEP] and works for backbones that emit no token_type_ids.
        """
        batch_size = encoded["input_ids"].shape[0]
        return torch.tensor(
            [
                [sequence_id == 1 for sequence_id in encoded.sequence_ids(row)]
                for row in range(batch_size)
            ],
            dtype=torch.bool,
        )

    def _extract_gold_positions(
        self,
        examples: List[QAExample],
        encoded: BatchEncoding,
        set_truncated_examples: set[str],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Maps character-level answer positions to token-level positions.

        In particular, for each example, the function computes (all offsets are
        start-inclusive, end-exclusive):
        - the answer offset within the context: [char_start, char_end)
        - each context token's offset within the context: [token_char_start, token_char_end)

        Only context tokens are inspected. Question-token offsets index into the
        question string, so comparing them with context offsets would mislabel answers
        that start near the beginning of the context.

        For two ranges [A, B) and [C, D) to overlap:
        1. The first range should start before the second ends (A < D)
        2. The second range should start before the first ends (C < B)
        These are the conditions the function uses to decide whether a token
        overlaps the answer.

        The function picks the FIRST and LAST context tokens overlapping the answer;
        those tokens fully determine the answer and align with the QA training objective.

        Labels:
        - unanswerable examples -> (0, 0), i.e. the [CLS] position.
        - answers that cannot be fully aligned (cut off by truncation, or no
          overlapping token) -> (-1, -1), which the loss ignores. Their question_id
          is added to `set_truncated_examples`.

        Returns:
        - gold_starts: Tensor (size: batch size) with token index for answer start
        - gold_ends: Tensor (size: batch size) with token index for answer end
        """
        offsets = encoded["offset_mapping"].tolist()
        batch_size = len(examples)
        assert (
            len(offsets) == batch_size
        ), f"Offset mapping size {len(offsets)} != batch size {batch_size}"

        gold_starts: List[int] = []
        gold_ends: List[int] = []
        for i, example in enumerate(examples):
            if example.is_impossible:
                gold_starts.append(0)
                gold_ends.append(0)
                continue
            assert (
                len(example.answer_starts) > 0
            ), f"Answerable question {example.question_id} without valid answers."

            # Only the first annotated answer is used as the training target.
            answer_text = example.answer_texts[0]
            char_start = example.answer_starts[0]
            char_end = char_start + len(answer_text)

            token_start = None
            token_end = None
            # End (in context characters) of the part of the context that survived truncation.
            window_char_end = 0
            for token_idx, (sequence_id, (token_char_start, token_char_end)) in enumerate(
                zip(encoded.sequence_ids(i), offsets[i])
            ):
                if sequence_id != 1:
                    # Special tokens, padding and question tokens.
                    continue
                window_char_end = max(window_char_end, token_char_end)
                if token_start is None and token_char_end > char_start:
                    token_start = token_idx
                if token_char_start < char_end:
                    token_end = token_idx

            answer_fits = (
                token_start is not None
                and token_end is not None
                and token_start <= token_end
                and window_char_end >= char_end
            )
            if not answer_fits:
                set_truncated_examples.add(example.question_id)
                gold_starts.append(-1)
                gold_ends.append(-1)
                continue

            gold_starts.append(token_start)
            gold_ends.append(token_end)

        gold_starts_tensor = torch.tensor(gold_starts, dtype=torch.long)
        gold_ends_tensor = torch.tensor(gold_ends, dtype=torch.long)
        assert (
            len(examples) == len(gold_starts_tensor) == len(gold_ends_tensor)
        ), "Ground-truth token shape mismatch."
        return gold_starts_tensor, gold_ends_tensor

    def predict(
        self, examples: Dict[str, QAExample], threshold_override: Optional[float] = None
    ) -> Dict[str, Prediction]:
        """
        Wrapper that automatically chunks large prediction requests to avoid OOM.

        Examples are processed in chunks of `config.eval_batch_size`. The module is
        put in eval mode and left there.
        """
        self.qa_module.eval()
        assert isinstance(examples, dict), "Incompatible input examples type."
        assert len(examples) > 0, "No examples to run prediction on."

        eval_batch_size = self.config.eval_batch_size
        if len(examples) <= eval_batch_size:
            return self._predict_batch(examples, threshold_override)

        all_qids = list(examples.keys())
        all_predictions = {}
        for i in range(0, len(all_qids), eval_batch_size):
            batch_qids = all_qids[i : i + eval_batch_size]
            batch_examples = {qid: examples[qid] for qid in batch_qids}
            all_predictions.update(
                self._predict_batch(batch_examples, threshold_override)
            )
        return all_predictions

    def _predict_batch(
        self, examples: Dict[str, QAExample], threshold_override: Optional[float] = None
    ) -> Dict[str, Prediction]:
        """
        Processes a single batch of examples.

        Encapsulates the forward pass and the logic that turns the per-token
        start/end logits into the final answer. Span boundaries are restricted to
        context tokens, and position 0 ([CLS]) is kept as the "no answer" slot. A
        null prediction is returned when the best start + end score does not beat
        the null score by more than the threshold, or when the best start/end
        tokens do not form a valid context span.
        """
        threshold = (
            threshold_override
            if threshold_override is not None
            else self.config.no_answer_threshold
        )

        qids, _, contexts, encoded = self._prepare_batch(examples)

        with torch.no_grad():
            start_logits, end_logits = self.qa_module(
                input_ids=encoded["input_ids"],
                attention_mask=encoded.get("attention_mask"),
                token_type_ids=encoded.get("token_type_ids"),
            )

        # The per-example loop below reads many scalars; doing it on CPU avoids
        # one device sync per .item() call.
        start_logits = start_logits.cpu()
        end_logits = end_logits.cpu()

        context_mask = self._context_token_mask(encoded)
        context_mask[:, 0] = True

        min_number = torch.finfo(start_logits.dtype).min
        start_logits = start_logits.masked_fill(~context_mask, min_number)
        end_logits = end_logits.masked_fill(~context_mask, min_number)

        # Start and end are chosen independently; invalid pairs become null below.
        best_start_indices = start_logits.argmax(dim=1)
        best_end_indices = end_logits.argmax(dim=1)

        offsets = encoded["offset_mapping"].tolist()
        predictions = {}
        for i, qid in enumerate(qids):
            if not context_mask[i, 1:].any():
                predictions[qid] = Prediction.null(question_id=qid)
                continue

            start_idx = best_start_indices[i].item()
            end_idx = best_end_indices[i].item()

            null_score = start_logits[i, 0].item() + end_logits[i, 0].item()
            best_span_score = (
                start_logits[i, start_idx].item() + end_logits[i, end_idx].item()
            )
            if best_span_score <= null_score + threshold:
                predictions[qid] = Prediction.null(question_id=qid)
                continue

            # Position 0 is [CLS], not context, so a span starting there has no
            # meaningful character offset. An end before the start is not a span.
            # (An end at 0 with a later start is covered by end_idx < start_idx.)
            if start_idx == 0 or end_idx < start_idx:
                predictions[qid] = Prediction.null(question_id=qid)
                continue

            start_char, _ = offsets[i][start_idx]
            _, end_char = offsets[i][end_idx]
            if start_char == 0 and end_char == 0:
                predictions[qid] = Prediction.null(question_id=qid)
                continue

            assert end_char >= start_char, (
                f"BUG: Invalid character span [{start_char}, {end_char}] "
                f"for valid token span [{start_idx}, {end_idx}] in question {qid}. "
                f"This indicates a problem with offset mapping or token masking."
            )

            answer_text = contexts[i][start_char:end_char].strip()
            if not answer_text:
                predictions[qid] = Prediction.null(question_id=qid)
                continue

            predictions[qid] = Prediction(
                question_id=qid,
                predicted_answer=answer_text,
                confidence=1.0,
                is_impossible=False,
            )
        return predictions

    def _prepare_batch(
        self, examples: Dict[str, QAExample]
    ) -> Tuple[List[str], List[str], List[str], BatchEncoding]:
        """
        Extracts questions and contexts in consistent order, then tokenizes them.

        Returns (qids, questions, contexts, encoded). All lists follow the dict's
        insertion order, which is also the row order of `encoded`.
        """
        qids = list(examples.keys())
        questions = [examples[qid].question for qid in qids]
        contexts = [examples[qid].context for qid in qids]
        encoded = self._encode_pairs(questions, contexts)
        return qids, questions, contexts, encoded

    def _encode_pairs(self, questions: list[str], contexts: list[str]) -> BatchEncoding:
        """
        Standardizes tokenization across all stages (train/inference).

        Questions are the first sequence and contexts the second. Only the context
        is truncated, every row is padded to `config.max_sequence_length`, and
        character offsets are returned for span extraction. For more information,
        see the HF documentation on padding and truncation:
        https://huggingface.co/docs/transformers/pad_truncation
        """
        assert len(questions) == len(
            contexts
        ), "Question and context lists are incompatible."
        return self.tokenizer(
            text=questions,
            text_pair=contexts,
            truncation="only_second",
            max_length=self.config.max_sequence_length,
            padding="max_length",
            return_offsets_mapping=True,
            return_tensors="pt",
        )

    @staticmethod
    def _print_training_setup(
        train_examples: Dict[str, QAExample],
        val_examples: Optional[Dict[str, QAExample]],
        config: BertQAConfig,
    ) -> None:
        """Print training setup information including data splits and configuration."""
        answerable_count = sum(
            1 for ex in train_examples.values() if not ex.is_impossible
        )
        unanswerable_count = len(train_examples) - answerable_count

        print(f"\n{'='*70}")
        print(f"TRAINING SETUP")
        print(f"{'='*70}")
        print(f"Total examples: {len(train_examples)}")
        print(f" Answerable: {answerable_count}")
        print(f" Unanswerable: {unanswerable_count}")
        assert len(train_examples) > 0, "No training examples!"

        if val_examples is not None:
            val_answerable = sum(
                1 for ex in val_examples.values() if not ex.is_impossible
            )
            val_unanswerable = len(val_examples) - val_answerable
            print(
                f"Validation: {len(val_examples)} total ({val_answerable} answerable, {val_unanswerable} unanswerable)"
            )

        print(f"\nConfiguration:")
        print(json.dumps(asdict(config), indent=2))
        print(f"{'='*70}\n")
|
|
|
|
|
|
|
|
class QAModule(torch.nn.Module):
    """
    Defines the initialization & wiring of a general-purpose encoder with a linear
    layer that outputs, for every token, one logit for "the answer starts here"
    and one for "the answer ends here".
    """

    def __init__(self, config: BertQAConfig) -> None:
        super().__init__()
        assert isinstance(config, BertQAConfig), "Incompatible configuration object."
        self.encoder = AutoModel.from_pretrained(config.backbone_name)

        # Two output units per token: index 0 -> start logit, index 1 -> end logit.
        self.linear_head = torch.nn.Linear(
            in_features=self.encoder.config.hidden_size, out_features=2
        )

        # Encoder and head both live on the configured device.
        self.to(config.device)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        input_ids: tokenized integer IDs from the vocabulary
        attention_mask: binary mask reflecting actual token Vs padding token
        token_type_ids: binary mask reflecting the segment: sentence A Vs sentence B

        Inputs are moved to the module's device. Optional inputs are forwarded to
        the encoder only when provided, because some encoders (e.g. DistilBERT) do
        not accept a `token_type_ids` argument at all.

        Returns (start_logits, end_logits): channel 0 and channel 1 of the head's
        per-token output, each with the leading (batch, seq_len) dimensions.
        """
        device = next(self.parameters()).device
        encoder_inputs = {"input_ids": input_ids.to(device)}
        if attention_mask is not None:
            encoder_inputs["attention_mask"] = attention_mask.to(device)
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids.to(device)

        encoder_output = self.encoder(**encoder_inputs)
        encoder_output_embeddings = encoder_output.last_hidden_state

        logits = self.linear_head(encoder_output_embeddings)
        start_logits, end_logits = logits[:, :, 0], logits[:, :, 1]
        return start_logits, end_logits
|
|
|