# enhanced-replica-model-pack/scripts/inference/infer_bert_roberta.py
# LUCIFerace's picture
# Add files using upload-large-folder tool
# 4a0f6a5 verified
"""Run archived BERT and RoBERTa classifiers against a dataset folder."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import pandas as pd
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, AutoTokenizer, DataCollatorWithPadding
def _locate_repo_root(start: Path) -> Path:
    """Return *start* or its nearest ancestor that contains a ``src`` entry.

    Candidates are examined from *start* outwards. When none of them holds a
    ``src`` entry the outermost ancestor is returned, which mirrors a plain
    "walk up until the parent stops changing" loop.
    """
    candidate = start
    for candidate in (start, *start.parents):
        if (candidate / "src").exists():
            break
    return candidate


# Repository root, discovered relative to this file rather than the CWD.
REPO_ROOT = _locate_repo_root(Path(__file__).resolve())

# Default locations derived from the repository root.
MODELS_ROOT = REPO_ROOT / "models"
DATASET_ROOT = REPO_ROOT.joinpath("data", "dataset")
OUTPUT_ROOT = REPO_ROOT.joinpath("outputs", "plm")

# Run on the GPU whenever torch reports one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Per-model checkpoint directory plus the classification-head settings that
# build_model() falls back to when model_meta.json does not provide them.
MODEL_SPECS = {
    "bert": {
        "model_dir": MODELS_ROOT / "bert-final",
        "hidden_size": 768,  # input width of the head's first linear layer
        "intermediate": 512,  # output width of the head's first linear layer
        "dropout": 0.5,
    },
    "roberta": {
        "model_dir": MODELS_ROOT / "roberta-final",
        "hidden_size": 1024,
        "intermediate": 512,
        "dropout": 0.3,
    },
}
def load_jsonl(path: Path) -> list[dict]:
    """Parse a JSON Lines file into a list of decoded values.

    Every line is stripped of surrounding whitespace; lines that are empty
    after stripping are ignored and all others are decoded with
    ``json.loads``.

    Args:
        path: Location of the UTF-8 encoded file to read.

    Returns:
        The decoded values, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with path.open("r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        # The list must be materialised while the file is still open.
        return [json.loads(text) for text in stripped_lines if text]
class TokenDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer output with integer labels.

    ``encoded`` is a mapping from field name to a per-example sequence (it
    must contain ``"input_ids"``); ``labels`` holds one integer per example.
    Both are kept by reference and converted to tensors lazily, one example
    at a time.
    """

    def __init__(self, encoded: dict[str, list[int]], labels: list[int]):
        self.labels = labels
        self.encoded = encoded

    def __len__(self) -> int:
        """Return the number of examples, taken from the ``"input_ids"`` field."""
        return len(self.encoded["input_ids"])

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return every encoded field of example *idx* plus a ``"labels"`` tensor."""
        sample: dict[str, torch.Tensor] = {}
        for field, column in self.encoded.items():
            sample[field] = torch.tensor(column[idx])
        # Labels are forced to int64 regardless of the Python value's type.
        sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample
class _ClassifierOutput:
    """Minimal result container exposing the classifier scores as ``.logits``."""

    __slots__ = ("logits",)

    def __init__(self, logits):
        self.logits = logits


class TransformerClassifier(nn.Module):
    """Encoder followed by a two-layer classification head.

    The head reads the hidden state at sequence position 0 and applies
    dropout -> linear -> ReLU -> linear.

    Args:
        base_model: Encoder whose call result exposes ``last_hidden_state``.
        hidden_size: Input width of the first linear layer; it has to match
            the last dimension of ``last_hidden_state``.
        intermediate: Output width of the first linear layer.
        dropout: Dropout probability applied to the position-0 vector.
        num_labels: Number of output classes.
    """

    def __init__(self, base_model, hidden_size: int, intermediate: int, dropout: float, num_labels: int = 2):
        super().__init__()
        # Attribute names double as state-dict key prefixes, so they must
        # stay stable for saved checkpoints to keep loading.
        self.base = base_model
        self.dropout = nn.Dropout(dropout)
        self.intermediate = nn.Linear(hidden_size, intermediate)
        self.activation = nn.ReLU()
        self.classifier = nn.Linear(intermediate, num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute class logits for a batch.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            labels: Accepted so that a batch dict containing labels can be
                unpacked into the call; it is not used and no loss is computed.
            **kwargs: Any further encoder inputs, passed through unchanged.

        Returns:
            An object whose ``logits`` attribute holds the head's output.
        """
        encoded = self.base(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        first_token = encoded.last_hidden_state[:, 0, :]
        hidden = self.activation(self.intermediate(self.dropout(first_token)))
        # A single reusable container type replaces building a new class
        # object with type(...) on every forward pass.
        return _ClassifierOutput(self.classifier(hidden))
def build_model(model_name: str):
    """Rebuild an archived classifier and its tokenizer.

    The checkpoint directory and fallback head settings come from
    ``MODEL_SPECS[model_name]``. If the directory contains
    ``model_meta.json``, its ``hidden_size`` / ``intermediate`` / ``dropout``
    entries take precedence over the fallbacks.

    Args:
        model_name: Key into ``MODEL_SPECS``.

    Returns:
        A 4-tuple ``(classifier, tokenizer, missing, unexpected)``. The
        classifier is on ``DEVICE`` and in eval mode; ``missing`` and
        ``unexpected`` are the key lists reported by ``load_state_dict``.

    Raises:
        KeyError: If ``model_name`` is not present in ``MODEL_SPECS``.
    """
    defaults = MODEL_SPECS[model_name]
    checkpoint_dir = defaults["model_dir"]

    # The encoder architecture is instantiated from the saved config; its
    # parameter values are supplied by the state dict loaded further down.
    # NOTE(review): trust_remote_code=True lets the checkpoint directory run
    # its own Python code - only point this at trusted model folders.
    encoder_config = AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True)
    encoder = AutoModel.from_config(encoder_config, trust_remote_code=True)

    meta_file = checkpoint_dir / "model_meta.json"
    overrides = json.loads(meta_file.read_text(encoding="utf-8")) if meta_file.exists() else {}

    def setting(key: str):
        """Return the metadata value for *key*, else the MODEL_SPECS fallback."""
        return overrides.get(key, defaults[key])

    classifier = TransformerClassifier(
        base_model=encoder,
        hidden_size=int(setting("hidden_size")),
        intermediate=int(setting("intermediate")),
        dropout=float(setting("dropout")),
    )

    # NOTE(review): torch.load unpickles the file; use trusted checkpoints
    # only (consider weights_only=True where the installed torch supports it).
    weights = torch.load(checkpoint_dir / "classifier_full_model.bin", map_location="cpu")
    # strict=False tolerates key mismatches; they are handed back to the
    # caller instead of raising.
    load_report = classifier.load_state_dict(weights, strict=False)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, trust_remote_code=True)

    classifier.to(DEVICE)
    classifier.eval()
    return classifier, tokenizer, load_report.missing_keys, load_report.unexpected_keys
def predict_records(model, tokenizer, records: list[dict], batch_size: int, max_length: int) -> list[float]:
    """Score records with a classifier and return the class-1 probabilities.

    Args:
        model: Callable classifier; invoked with the collated batch unpacked
            as keyword arguments and expected to return an object with a
            ``logits`` attribute.
        tokenizer: Tokenizer used both to encode the texts and to pad each
            batch.
        records: Dicts that each provide a ``"text"`` and a ``"label"`` key.
        batch_size: Number of records per forward pass.
        max_length: Truncation length handed to the tokenizer.

    Returns:
        One probability per record (softmax over the last logits dimension,
        index 1), in the same order as ``records``. An empty ``records`` list
        yields an empty list without touching the model or tokenizer.

    Raises:
        KeyError: If a record lacks ``"text"`` or ``"label"``.
    """
    # Nothing to score: skip the tokenizer, which is not guaranteed to accept
    # an empty batch.
    if not records:
        return []

    texts = [record["text"] for record in records]
    labels = [int(record["label"]) for record in records]

    # Padding is deferred to the collator so each batch is only padded to the
    # length of its own longest sequence.
    encoded = tokenizer(texts, truncation=True, padding=False, max_length=max_length)
    loader = torch.utils.data.DataLoader(
        TokenDataset(encoded, labels),
        batch_size=batch_size,
        collate_fn=DataCollatorWithPadding(tokenizer),
    )

    all_probs: list[float] = []
    with torch.no_grad():
        for batch in loader:
            batch = {
                key: value.to(DEVICE) if isinstance(value, torch.Tensor) else value
                for key, value in batch.items()
            }
            logits = model(**batch).logits
            positive = torch.softmax(logits, dim=-1)[:, 1]
            # tolist() already produces plain Python floats.
            all_probs.extend(positive.cpu().tolist())
    return all_probs
def main() -> None:
    """Score dataset splits with both archived models and save CSV predictions.

    For each of the ``bert`` and ``roberta`` checkpoints, every requested
    split found at ``<dataset-root>/<dataset>/<split>.jsonl`` is scored and
    written to ``<output-root>/<dataset>/<model>_<split>_predictions.csv``.
    Splits whose file is missing or contains no records are skipped.
    """
    parser = argparse.ArgumentParser(description="Run archived BERT and RoBERTa checkpoints.")
    parser.add_argument("--dataset", required=True, help="Dataset name under data/dataset/")
    parser.add_argument("--dataset-root", default=str(DATASET_ROOT))
    parser.add_argument("--output-root", default=str(OUTPUT_ROOT))
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--max-length", type=int, default=512)
    parser.add_argument("--include-train", action="store_true")
    options = parser.parse_args()

    source_dir = Path(options.dataset_root) / options.dataset
    target_dir = Path(options.output_root) / options.dataset
    target_dir.mkdir(parents=True, exist_ok=True)

    # The train split is opt-in and, when requested, processed first.
    split_names = ["dev", "test"]
    if options.include_train:
        split_names.insert(0, "train")

    for model_name in ("bert", "roberta"):
        model, tokenizer, missing, unexpected = build_model(model_name)
        print(f"[{model_name}] missing={len(missing)} unexpected={len(unexpected)}")

        for split in split_names:
            jsonl_file = source_dir / f"{split}.jsonl"
            records = load_jsonl(jsonl_file) if jsonl_file.exists() else []
            if not records:
                continue

            probs = predict_records(model, tokenizer, records, options.batch_size, options.max_length)
            texts = [record["text"] for record in records]
            columns = {
                "text": texts,
                "label": [int(record["label"]) for record in records],
                "length": [len(str(text)) for text in texts],  # character count
                "pred_prob": probs,
                "pred_label_05": [int(prob >= 0.5) for prob in probs],  # 0.5 cut-off
            }

            output_path = target_dir / f"{model_name}_{split}_predictions.csv"
            pd.DataFrame(columns).to_csv(output_path, index=False, encoding="utf-8")
            print(f"saved {output_path}")

        # Drop this model before the next checkpoint is loaded.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()