# Scraped file-viewer header (not part of the module): uploaded by BabakScrapes,
# commit 7c93384 (verified) "Upgrade used packages", file size 5.92 kB.
from __future__ import annotations
import os
import re
from pathlib import Path
from typing import List, Tuple
import numpy as np
import torch
from transformers import (
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoTokenizer,
)
# Directory containing this file; the model folders are resolved relative to it.
BASE_DIR = Path(__file__).resolve().parent
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Keep the existing folder names from the repo.
# NOTE(review): "classfication_model" is presumably spelled this way to match
# the on-disk folder name -- verify before "correcting" the typo.
CLAUSE_MODEL_DIR = BASE_DIR / "clause_model_512"
CLASSIFICATION_MODEL_DIR = BASE_DIR / "classfication_model"
# A single tokenizer, loaded by the "roberta-base" identifier rather than from
# the model folders, is used by both the clause-segmentation and the
# clause-classification functions in this file.
# NOTE(review): add_prefix_space=True is presumably required because
# get_pred_clause_labels passes pre-split words (is_split_into_words=True)
# -- confirm against the tokenizer documentation.
tokenizer = AutoTokenizer.from_pretrained(
"roberta-base",
use_fast=True,
add_prefix_space=True
)
# Token-classification model used to predict one clause-boundary label per
# word. Loaded at import time, moved to DEVICE and switched to eval mode.
clause_model = AutoModelForTokenClassification.from_pretrained(
str(CLAUSE_MODEL_DIR)
).to(DEVICE).eval()
# Sequence-classification model used to label whole clauses. Loaded at import
# time, moved to DEVICE and switched to eval mode.
classification_model = AutoModelForSequenceClassification.from_pretrained(
str(CLASSIFICATION_MODEL_DIR)
).to(DEVICE).eval()
# Clause label -> triple of attribute values.
# Across the table the first slot is "specific"/"generic", the second is
# "dynamic"/"stative" and the third is "episodic"/"static"/"habitual"; labels
# with no attributes carry ("NA", "NA", "NA").
# NOTE(review): insertion order defines the class indices derived below, so it
# presumably has to match the index order the classification model was trained
# with -- do not reorder without confirming.
labels2attrs = {
    # Events
    "##BOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##BOUNDED EVENT (GENERIC)": ("generic", "dynamic", "episodic"),
    "##UNBOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "static"),
    "##UNBOUNDED EVENT (GENERIC)": ("generic", "dynamic", "static"),
    # States
    "##BASIC STATE": ("specific", "stative", "static"),
    "##COERCED STATE (SPECIFIC)": ("specific", "dynamic", "static"),
    "##COERCED STATE (GENERIC)": ("generic", "dynamic", "static"),
    "##PERFECT COERCED STATE (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##PERFECT COERCED STATE (GENERIC)": ("generic", "dynamic", "episodic"),
    # Generic / generalizing sentences
    "##GENERIC SENTENCE (DYNAMIC)": ("generic", "dynamic", "habitual"),
    "##GENERIC SENTENCE (STATIC)": ("generic", "stative", "static"),
    "##GENERIC SENTENCE (HABITUAL)": ("generic", "stative", "habitual"),
    "##GENERALIZING SENTENCE (DYNAMIC)": ("specific", "dynamic", "habitual"),
    "##GENERALIZING SENTENCE (STATIVE)": ("specific", "stative", "habitual"),
    # Labels without attributes
    "##QUESTION": ("NA", "NA", "NA"),
    "##IMPERATIVE": ("NA", "NA", "NA"),
    "##NONSENSE": ("NA", "NA", "NA"),
    "##OTHER": ("NA", "NA", "NA"),
}

# Positional index <-> label name, both derived from the table's insertion order.
index2label = dict(enumerate(labels2attrs))
label2index = {name: position for position, name in index2label.items()}
def split_sentences(text: str) -> List[str]:
    """Split *text* into sentences with a simple punctuation heuristic.

    Every run of whitespace is first collapsed to a single space. The text is
    then cut at whitespace that directly follows ``.``, ``!`` or ``?``; the
    punctuation stays attached to the sentence it ends.

    Args:
        text: Raw input text.

    Returns:
        The non-empty sentences in order, or an empty list when *text* is
        empty or whitespace only.
    """
    collapsed = re.sub(r"\s+", " ", text).strip()
    if collapsed == "":
        return []

    # Lightweight sentence splitting to avoid spaCy.
    pieces: List[str] = []
    for fragment in re.split(r"(?<=[.!?])\s+", collapsed):
        fragment = fragment.strip()
        if fragment:
            pieces.append(fragment)
    return pieces
def auto_split(text: str, max_words: int = 200) -> List[str]:
    """Group the sentences of *text* into snippets of at most *max_words* words.

    Sentences are packed greedily: a sentence is added to the current snippet
    while the total stays within *max_words*, otherwise the current snippet is
    closed and a new one is started.

    A single sentence longer than *max_words* (for example text without any
    sentence punctuation) is hard-split into consecutive chunks of exactly
    *max_words* words, with the remainder carried into the next snippet.
    Without this, such a sentence would become one unbounded snippet, and the
    token-level truncation in ``get_pred_clause_labels`` would leave its tail
    without real predictions.

    Args:
        text: Raw input text.
        max_words: Word budget per snippet. When it is not positive no hard
            splitting is possible, so every sentence becomes its own snippet.

    Returns:
        The snippets in order, each a single-space-joined string; an empty
        list when *text* contains no sentences.
    """
    sentences = split_sentences(text)
    if not sentences:
        return []

    snippets: List[str] = []
    current_words: List[str] = []
    for sent in sentences:
        sent_words = sent.split()

        # Close the running snippet if this sentence would overflow it.
        if current_words and len(current_words) + len(sent_words) > max_words:
            snippets.append(" ".join(current_words))
            current_words = []

        # An oversized sentence can only get here with an empty running
        # snippet (a non-empty one was just flushed), so full chunks can be
        # emitted directly.
        while max_words > 0 and len(sent_words) > max_words:
            snippets.append(" ".join(sent_words[:max_words]))
            sent_words = sent_words[max_words:]

        current_words.extend(sent_words)

    if current_words:
        snippets.append(" ".join(current_words))
    return snippets
def majority_vote(values: List[int]) -> int:
    """Return the most frequent label in *values*.

    Ties are resolved in favour of the smallest label, because the argmax over
    the per-label counts returns the first maximum.

    Args:
        values: Non-negative integer labels.

    Returns:
        The winning label, or ``1`` when *values* is empty.
    """
    if values:
        histogram = np.bincount(values)
        return int(histogram.argmax())
    return 1
@torch.no_grad()
def get_pred_clause_labels(text: str) -> List[int]:
    """Predict one clause label per whitespace-separated word of *text*.

    The words are tokenized as pre-split input (truncated to 512 tokens) and
    passed through ``clause_model``. The arg-max predictions of all sub-word
    tokens belonging to a word are reduced to one label with
    ``majority_vote``.

    Args:
        text: Input snippet.

    Returns:
        A list with exactly one label per word. A word that received no
        sub-word prediction (for instance because it was cut off by the
        512-token truncation) gets the default label ``1``. Empty or
        whitespace-only input yields an empty list.
    """
    tokens = text.strip().split()
    if not tokens:
        return []

    encoding = tokenizer(
        tokens,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    # Maps each sub-word position to the index of its source word
    # (None for special tokens).
    word_ids = encoding.word_ids(batch_index=0)

    device_inputs = {name: tensor.to(DEVICE) for name, tensor in encoding.items()}
    scores = clause_model(**device_inputs).logits[0]
    subword_preds = scores.argmax(dim=-1).cpu().tolist()

    # Collect the sub-word predictions of every word.
    votes_per_word: List[List[int]] = [[] for _ in tokens]
    for prediction, word_id in zip(subword_preds, word_ids):
        if word_id is not None:
            votes_per_word[word_id].append(prediction)

    return [majority_vote(votes) if votes else 1 for votes in votes_per_word]
def seg_clause(text: str) -> List[str]:
    """Segment *text* into clauses using the per-word clause labels.

    A new clause is opened at the first word and after every word labelled
    ``2``. A clause is closed and emitted at a word labelled ``2`` whose
    predecessor was labelled ``0`` or ``1``. Whatever is still open after the
    last word is emitted as the final clause.

    NOTE(review): a word labelled ``2`` that is the first word, or that
    directly follows another word labelled ``2``, opens a clause that is
    discarded on the next iteration, so that word is dropped unless it is the
    last word of *text*. Confirm whether this is intended.

    Args:
        text: Input snippet.

    Returns:
        The clauses in order, each a single-space-joined string; an empty list
        for empty or whitespace-only input.
    """
    tokens = text.strip().split()
    if not tokens:
        return []

    tags = get_pred_clause_labels(text)

    finished: List[List[str]] = []
    open_clause: List[str] = []
    last_tag = 2  # behaves as if a clause had just ended before the first word
    for word, tag in zip(tokens, tags):
        if last_tag == 2:
            # The previous word ended a clause (or this is the first word).
            open_clause = []
        open_clause.append(word)
        if tag == 2 and last_tag in (0, 1):
            finished.append(open_clause)
            open_clause = []
        last_tag = tag

    if open_clause:
        finished.append(open_clause)

    return [" ".join(parts) for parts in finished if parts]
@torch.no_grad()
def get_pred_classification_labels(
    clauses: List[str], batch_size: int = 32
) -> List[Tuple[str, str]]:
    """Classify every clause and pair it with its predicted label name.

    The clauses are processed in consecutive batches of *batch_size*. Each
    batch is padded, truncated to 128 tokens and run through
    ``classification_model``; the arg-max class index is mapped to a label
    name through ``index2label``.

    Args:
        clauses: Clause strings to classify.
        batch_size: Number of clauses per forward pass.

    Returns:
        ``(clause, label)`` tuples in input order; an empty list when
        *clauses* is empty.
    """
    labelled: List[Tuple[str, str]] = []
    for start in range(0, len(clauses), batch_size):
        chunk = clauses[start : start + batch_size]
        encoding = tokenizer(
            chunk,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        device_inputs = {name: tensor.to(DEVICE) for name, tensor in encoding.items()}
        scores = classification_model(**device_inputs).logits
        class_ids = scores.argmax(dim=-1).cpu().tolist()
        names = [index2label[class_id] for class_id in class_ids]
        labelled.extend(zip(chunk, names))
    return labelled
def run_pipeline(text: str):
    """Run the full pipeline: split into snippets, segment clauses, classify.

    Args:
        text: Raw input text.

    Returns:
        A pair ``(output_clauses, clause2labels)``. ``output_clauses`` is a
        list of ``(clause, number)`` tuples where ``number`` is the 1-based
        position of the clause as a string. ``clause2labels`` is the list of
        ``(clause, label)`` tuples returned by
        ``get_pred_classification_labels``.
    """
    # Segment every snippet and flatten the clauses into one ordered list.
    clauses: List[str] = [
        clause for snippet in auto_split(text) for clause in seg_clause(snippet)
    ]
    labelled = get_pred_classification_labels(clauses)
    numbered = [
        (clause, str(position)) for position, clause in enumerate(clauses, start=1)
    ]
    return numbered, labelled