"""Clause segmentation and clause-type classification pipeline.

Text is split into sentence-aligned snippets, each snippet is segmented into
clauses by a token-classification model, and every clause is then labelled by
a sequence-classification model. Both models and the shared tokenizer are
loaded once, at import time.
"""

from __future__ import annotations

import os
import re
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
)

BASE_DIR = Path(__file__).resolve().parent
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep the existing folder names from the repo (including the misspelled
# "classfication_model"); renaming them here would break model loading.
CLAUSE_MODEL_DIR = BASE_DIR / "clause_model_512"
CLASSIFICATION_MODEL_DIR = BASE_DIR / "classfication_model"

# One tokenizer is shared by both models. add_prefix_space=True is needed
# because get_pred_clause_labels feeds pre-split words
# (is_split_into_words=True) to RoBERTa's byte-level BPE tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    "roberta-base", use_fast=True, add_prefix_space=True
)
clause_model = AutoModelForTokenClassification.from_pretrained(
    str(CLAUSE_MODEL_DIR)
).to(DEVICE).eval()
classification_model = AutoModelForSequenceClassification.from_pretrained(
    str(CLASSIFICATION_MODEL_DIR)
).to(DEVICE).eval()

# Clause-type label -> attribute triple. Within this file only the key ORDER
# is used: it defines the class index that classification_model's argmax is
# mapped back through (label2index / index2label below). The order presumably
# matches the one used when the classifier was trained -- do not reorder.
labels2attrs = {
    "##BOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##BOUNDED EVENT (GENERIC)": ("generic", "dynamic", "episodic"),
    "##UNBOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "static"),
    "##UNBOUNDED EVENT (GENERIC)": ("generic", "dynamic", "static"),
    "##BASIC STATE": ("specific", "stative", "static"),
    "##COERCED STATE (SPECIFIC)": ("specific", "dynamic", "static"),
    "##COERCED STATE (GENERIC)": ("generic", "dynamic", "static"),
    "##PERFECT COERCED STATE (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##PERFECT COERCED STATE (GENERIC)": ("generic", "dynamic", "episodic"),
    "##GENERIC SENTENCE (DYNAMIC)": ("generic", "dynamic", "habitual"),
    "##GENERIC SENTENCE (STATIC)": ("generic", "stative", "static"),
    "##GENERIC SENTENCE (HABITUAL)": ("generic", "stative", "habitual"),
    "##GENERALIZING SENTENCE (DYNAMIC)": ("specific", "dynamic", "habitual"),
    "##GENERALIZING SENTENCE (STATIVE)": ("specific", "stative", "habitual"),
    "##QUESTION": ("NA", "NA", "NA"),
    "##IMPERATIVE": ("NA", "NA", "NA"),
    "##NONSENSE": ("NA", "NA", "NA"),
    "##OTHER": ("NA", "NA", "NA"),
}
label2index = {label: i for i, label in enumerate(labels2attrs.keys())}
index2label = {i: label for label, i in label2index.items()}


def split_sentences(text: str) -> List[str]:
    """Split *text* into sentences with a lightweight regex.

    Runs of whitespace are first collapsed to single spaces. A sentence
    boundary is any whitespace that directly follows ``.``, ``!`` or ``?``.
    Returns an empty list for blank input.
    """
    text = re.sub(r"\s+", " ", text).strip()
    if not text:
        return []
    # Lightweight sentence splitting to avoid spaCy.
    sentences = re.split(r"(?<=[.!?])\s+", text)
    return [s.strip() for s in sentences if s.strip()]


def auto_split(text: str, max_words: int = 200) -> List[str]:
    """Greedily pack whole sentences into snippets of at most *max_words* words.

    Words are whitespace-separated tokens. Sentences are never split, so a
    single sentence longer than *max_words* becomes its own (oversized)
    snippet. Returns an empty list for blank input.
    """
    sentences = split_sentences(text)
    if not sentences:
        return []
    snippets: List[str] = []
    current_words: List[str] = []
    for sent in sentences:
        sent_words = sent.split()
        # Flush the pending snippet only when it is non-empty and adding this
        # sentence would push it past the budget.
        if current_words and len(current_words) + len(sent_words) > max_words:
            snippets.append(" ".join(current_words).strip())
            current_words = sent_words[:]
        else:
            current_words.extend(sent_words)
    if current_words:
        snippets.append(" ".join(current_words).strip())
    return snippets


def majority_vote(values: List[int]) -> int:
    """Return the most frequent value in *values*.

    Values must be non-negative ints (``np.bincount`` requirement). Ties go
    to the smallest value because ``np.argmax`` returns the first maximum.
    An empty list yields 1.
    """
    if not values:
        return 1
    counts = np.bincount(values)
    return int(np.argmax(counts))


@torch.no_grad()
def get_pred_clause_labels(text: str) -> List[int]:
    """Predict one clause-model label per whitespace-separated word of *text*.

    Each word's label is the majority vote over the argmax predictions of its
    sub-word tokens. Input is truncated to 512 tokens; words whose tokens were
    all cut off get no predictions and default to label 1.
    """
    words = text.strip().split()
    if not words:
        return []
    enc = tokenizer(
        words,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    # Maps each token position to the index of the word it came from
    # (None for special tokens).
    word_ids = enc.word_ids(batch_index=0)
    model_inputs = {k: v.to(DEVICE) for k, v in enc.items()}
    logits = clause_model(**model_inputs).logits[0]
    token_preds = logits.argmax(dim=-1).cpu().tolist()

    aligned: List[List[int]] = [[] for _ in words]
    for token_idx, word_id in enumerate(word_ids):
        if word_id is None:
            continue
        aligned[word_id].append(token_preds[token_idx])
    # majority_vote already returns 1 for words with no surviving tokens.
    return [majority_vote(x) for x in aligned]


def seg_clause(text: str) -> List[str]:
    """Segment *text* into clauses using the clause model's per-word labels.

    A word labelled 2 that directly follows a word labelled 0 or 1 closes the
    current clause and is kept as that clause's last word. A word labelled 2
    after another 2 (or at the very start) cannot close a clause. It is kept
    and carried into the next clause rather than being discarded. Words left
    pending at the end form a final clause. Every input word therefore
    appears in exactly one returned clause.
    """
    words = text.strip().split()
    if not words:
        return []
    labels = get_pred_clause_labels(text)
    segmented_clauses: List[List[str]] = []
    prev_label = 2
    current_clause: List[str] = []
    for token, label in zip(words, labels):
        # Always keep the word. The previous version reset the pending clause
        # whenever prev_label == 2, which silently dropped a word labelled 2
        # that followed another 2 in mid-text (yet kept it at end of text).
        current_clause.append(token)
        if label == 2 and prev_label in (0, 1):
            segmented_clauses.append(current_clause)
            current_clause = []
        prev_label = label
    if current_clause:
        segmented_clauses.append(current_clause)
    return [" ".join(clause) for clause in segmented_clauses if clause]


@torch.no_grad()
def get_pred_classification_labels(
    clauses: List[str], batch_size: int = 32
) -> List[Tuple[str, str]]:
    """Classify each clause and return ``(clause, label)`` pairs in input order.

    Clauses are processed in batches of *batch_size*, padded per batch and
    truncated to 128 tokens. Predicted class indices are mapped back through
    ``index2label``; an index with no entry there raises ``KeyError``.
    """
    results: List[Tuple[str, str]] = []
    for i in range(0, len(clauses), batch_size):
        batch = clauses[i : i + batch_size]
        enc = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        model_inputs = {k: v.to(DEVICE) for k, v in enc.items()}
        logits = classification_model(**model_inputs).logits
        pred_ids = logits.argmax(dim=-1).cpu().tolist()
        pred_labels = [index2label[idx] for idx in pred_ids]
        results.extend(zip(batch, pred_labels))
    return results


def run_pipeline(
    text: str,
) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    """Run splitting, clause segmentation and classification on *text*.

    Returns:
        A pair ``(output_clauses, clause2labels)``:
        ``output_clauses`` lists ``(clause, index)`` with a 1-based index as a
        string; ``clause2labels`` lists ``(clause, label)`` in the same order.
    """
    snippets = auto_split(text)
    all_clauses: List[str] = []
    for snippet in snippets:
        all_clauses.extend(seg_clause(snippet))
    clause2labels = get_pred_classification_labels(all_clauses)
    output_clauses = [(clause, str(i + 1)) for i, clause in enumerate(all_clauses)]
    return output_clauses, clause2labels