| from __future__ import annotations |
|
|
| import os |
| import re |
| from pathlib import Path |
| from typing import List, Tuple |
|
|
| import numpy as np |
| import torch |
| from transformers import ( |
| AutoModelForSequenceClassification, |
| AutoModelForTokenClassification, |
| AutoTokenizer, |
| ) |
|
|
BASE_DIR = Path(__file__).resolve().parent

# Use the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Fine-tuned checkpoints are loaded from directories next to this module.
# NOTE(review): "classfication_model" looks misspelled, but it must match the
# directory name on disk, so it is intentionally left unchanged.
CLAUSE_MODEL_DIR = BASE_DIR / "clause_model_512"
CLASSIFICATION_MODEL_DIR = BASE_DIR / "classfication_model"

# One tokenizer is shared by both models. add_prefix_space=True is what allows
# the fast RoBERTa tokenizer to accept pre-split words (is_split_into_words=True).
tokenizer = AutoTokenizer.from_pretrained(
    "roberta-base", use_fast=True, add_prefix_space=True
)

# Token-classification model whose per-token predictions drive clause splitting;
# moved to DEVICE and switched to eval mode.
clause_model = AutoModelForTokenClassification.from_pretrained(str(CLAUSE_MODEL_DIR))
clause_model = clause_model.to(DEVICE)
clause_model.eval()

# Sequence-classification model that assigns a clause-type label to each clause;
# moved to DEVICE and switched to eval mode.
classification_model = AutoModelForSequenceClassification.from_pretrained(
    str(CLASSIFICATION_MODEL_DIR)
)
classification_model = classification_model.to(DEVICE)
classification_model.eval()
|
|
# Clause-type label -> attribute triple.
# Observed slot values: first "specific"/"generic", second "dynamic"/"stative",
# third "episodic"/"static"/"habitual"; ("NA", "NA", "NA") marks labels with no
# attributes. Presumably (genericity, eventivity, aspect) -- TODO confirm.
# NOTE(review): the insertion order of this dict defines the class indices used
# to decode classifier outputs (via index2label below). It presumably mirrors
# the order used at training time, so do not reorder or rename keys. Note the
# key suffixes differ: "(STATIC)" for GENERIC SENTENCE vs "(STATIVE)" for
# GENERALIZING SENTENCE; confirm against the training labels before touching.
labels2attrs = {
    "##BOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##BOUNDED EVENT (GENERIC)": ("generic", "dynamic", "episodic"),
    "##UNBOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "static"),
    "##UNBOUNDED EVENT (GENERIC)": ("generic", "dynamic", "static"),
    "##BASIC STATE": ("specific", "stative", "static"),
    "##COERCED STATE (SPECIFIC)": ("specific", "dynamic", "static"),
    "##COERCED STATE (GENERIC)": ("generic", "dynamic", "static"),
    "##PERFECT COERCED STATE (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##PERFECT COERCED STATE (GENERIC)": ("generic", "dynamic", "episodic"),
    "##GENERIC SENTENCE (DYNAMIC)": ("generic", "dynamic", "habitual"),
    "##GENERIC SENTENCE (STATIC)": ("generic", "stative", "static"),
    "##GENERIC SENTENCE (HABITUAL)": ("generic", "stative", "habitual"),
    "##GENERALIZING SENTENCE (DYNAMIC)": ("specific", "dynamic", "habitual"),
    "##GENERALIZING SENTENCE (STATIVE)": ("specific", "stative", "habitual"),
    "##QUESTION": ("NA", "NA", "NA"),
    "##IMPERATIVE": ("NA", "NA", "NA"),
    "##NONSENSE": ("NA", "NA", "NA"),
    "##OTHER": ("NA", "NA", "NA"),
}

# Label <-> class-index lookups, both derived from the dict's insertion order.
label2index = {label: position for position, label in enumerate(labels2attrs)}
index2label = dict(enumerate(labels2attrs))
|
|
|
|
def split_sentences(text: str) -> List[str]:
    """Split *text* into sentences after ``.``, ``!`` or ``?``.

    Runs of whitespace are first collapsed to single spaces. A break is made
    only where terminal punctuation is immediately followed by whitespace, so
    "e.g. this" is split after "e.g." while "3.14" is left intact.

    Returns:
        Non-empty, stripped sentences in their original order; ``[]`` when the
        input is empty or all whitespace.
    """
    collapsed = re.sub(r"\s+", " ", text).strip()
    if collapsed == "":
        return []
    pieces = (piece.strip() for piece in re.split(r"(?<=[.!?])\s+", collapsed))
    return [piece for piece in pieces if piece]
|
|
|
|
def auto_split(text: str, max_words: int = 200) -> List[str]:
    """Pack consecutive sentences of *text* into snippets of about max_words words.

    Sentences (from ``split_sentences``) are appended to the current snippet
    while the snippet's word count stays at or below *max_words*; otherwise a
    new snippet is started. A sentence is never split, so a single sentence
    longer than *max_words* forms its own oversized snippet.

    Args:
        text: Raw input text.
        max_words: Soft upper bound on whitespace-separated words per snippet.

    Returns:
        Snippets with words joined by single spaces; ``[]`` for blank input.
    """
    groups: List[List[str]] = []
    for sentence in split_sentences(text):
        tokens = sentence.split()
        if groups and len(groups[-1]) + len(tokens) <= max_words:
            groups[-1].extend(tokens)
        else:
            groups.append(tokens)
    return [" ".join(group) for group in groups if group]
|
|
|
|
def majority_vote(values: List[int]) -> int:
    """Return the most frequent value in *values*.

    Ties go to the smallest value (``np.argmax`` returns the first maximum of
    the ``np.bincount`` histogram). *values* must be non-negative integers, as
    required by ``np.bincount``.

    Returns:
        The winning value as a Python ``int``; ``1`` when *values* is empty.
    """
    if values:
        tally = np.bincount(values)
        return int(tally.argmax())
    # No votes at all: fall back to label 1.
    return 1
|
|
|
|
@torch.no_grad()
def _predict_clause_chunk(words: List[str], max_length: int) -> List[int]:
    """Label every word of a chunk that fits in one clause-model forward pass.

    Each word's label is the majority vote over the predictions of its subword
    tokens. A word whose tokens were all truncated away defaults to label 1.
    """
    enc = tokenizer(
        words,
        is_split_into_words=True,
        return_tensors="pt",
        truncation=True,  # backstop for a single word longer than the window
        max_length=max_length,
    )
    word_ids = enc.word_ids(batch_index=0)
    model_inputs = {k: v.to(DEVICE) for k, v in enc.items()}

    logits = clause_model(**model_inputs).logits[0]
    token_preds = logits.argmax(dim=-1).cpu().tolist()

    # Gather token-level predictions per word; special tokens have no word id.
    aligned: List[List[int]] = [[] for _ in words]
    for token_idx, word_id in enumerate(word_ids):
        if word_id is not None:
            aligned[word_id].append(token_preds[token_idx])

    # majority_vote already returns 1 for a word with no surviving tokens.
    return [majority_vote(votes) for votes in aligned]


@torch.no_grad()
def get_pred_clause_labels(text: str, max_length: int = 512) -> List[int]:
    """Predict a clause-segmentation label for every whitespace-separated word.

    The clause model sees at most ``max_length`` tokens per call (special
    tokens included). Longer inputs are split at word boundaries into
    consecutive chunks that each fit, and every chunk is labelled separately.
    Previously, truncation made every word past the limit silently receive
    the placeholder label 1. Inputs that fit in one window are processed
    exactly as before.

    Args:
        text: Input text; split on whitespace into words.
        max_length: Maximum tokens per model call, special tokens included.

    Returns:
        One integer label per word, in order; ``[]`` for blank input.
    """
    words = text.strip().split()
    if not words:
        return []

    # Count subword tokens per word. With is_split_into_words=True each word
    # is encoded on its own, so its count should not depend on which chunk it
    # lands in (the chunk call still truncates as a safety net).
    probe = tokenizer(words, is_split_into_words=True, add_special_tokens=False)
    token_counts = [0] * len(words)
    for word_id in probe.word_ids():
        if word_id is not None:
            token_counts[word_id] += 1

    # Room left for word tokens once <s>/</s>-style special tokens are added.
    budget = max(1, max_length - tokenizer.num_special_tokens_to_add(pair=False))

    pred_labels: List[int] = []
    start = 0
    while start < len(words):
        # Always take at least one word so the loop makes progress.
        end = start + 1
        used = token_counts[start]
        while end < len(words) and used + token_counts[end] <= budget:
            used += token_counts[end]
            end += 1
        pred_labels.extend(_predict_clause_chunk(words[start:end], max_length))
        start = end
    return pred_labels
|
|
|
|
def seg_clause(text: str) -> List[str]:
    """Segment *text* into clauses using per-word clause labels.

    The text is split on whitespace and each word is paired with the label
    returned by ``get_pred_clause_labels``. A word labelled ``2`` whose
    previous word was labelled ``0`` or ``1`` closes the current clause (the
    ``2``-labelled word is kept in it), and the word after any ``2`` opens a
    fresh clause. Labels ``0`` and ``1`` are handled identically here.

    NOTE(review): a word labelled ``2`` that directly follows another ``2`` opens
    a one-word clause that is never closed, so it is discarded when the next
    word resets ``current_clause``. Because ``prev_label`` starts at ``2``, the
    same happens to a first word labelled ``2``. Such a word survives only if
    it is the last word. Confirm that dropping these words is intended.

    Args:
        text: Input text; it is split on whitespace.

    Returns:
        Clauses with words re-joined by single spaces; empty clauses are
        omitted. Returns ``[]`` for blank input.
    """
    words = text.strip().split()
    if not words:
        return []

    # One predicted label per word (get_pred_clause_labels uses the same split).
    labels = get_pred_clause_labels(text)
    segmented_clauses: List[List[str]] = []

    # Behave as if a clause just ended, so the first word opens a new one.
    prev_label = 2
    current_clause: List[str] | None = None

    for token, label in zip(words, labels):
        # The previous word closed a clause (or this is the first word): start a
        # new clause. An unclosed clause from a 2-after-2 word is dropped here.
        if prev_label == 2:
            current_clause = []

        # Defensive: after the reset above, current_clause is never None here.
        if current_clause is not None:
            current_clause.append(token)

        # A 2 right after a 0/1 word ends the clause, keeping the 2-labelled word.
        if label == 2 and prev_label in [0, 1]:
            segmented_clauses.append(current_clause[:])
            current_clause = None

        prev_label = label

    # Keep a trailing clause that never received a closing label.
    if current_clause:
        segmented_clauses.append(current_clause[:])

    return [" ".join(clause) for clause in segmented_clauses if clause]
|
|
|
|
@torch.no_grad()
def get_pred_classification_labels(
    clauses: List[str], batch_size: int = 32
) -> List[Tuple[str, str]]:
    """Classify each clause and pair it with its predicted clause-type label.

    Clauses are encoded in batches of *batch_size* (padded, truncated to 128
    tokens). The argmax class index is then mapped to a label through
    ``index2label``.

    Returns:
        ``(clause, label)`` pairs in input order; ``[]`` for no clauses.
    """
    paired: List[Tuple[str, str]] = []
    for start in range(0, len(clauses), batch_size):
        chunk = clauses[start : start + batch_size]
        encoded = tokenizer(
            chunk,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )
        inputs = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
        best = classification_model(**inputs).logits.argmax(dim=-1)
        for clause, idx in zip(chunk, best.cpu().tolist()):
            paired.append((clause, index2label[idx]))
    return paired
|
|
|
|
def run_pipeline(text: str):
    """Run snippet splitting, clause segmentation and clause classification.

    *text* is packed into snippets by ``auto_split``. Each snippet is segmented
    into clauses by ``seg_clause``, and all clauses are classified together.

    Returns:
        A tuple ``(output_clauses, clause2labels)``:
        ``output_clauses`` lists ``(clause, position)`` with 1-based positions
        as strings; ``clause2labels`` lists ``(clause, label)`` pairs from
        ``get_pred_classification_labels`` in the same clause order.
    """
    all_clauses = [
        clause for snippet in auto_split(text) for clause in seg_clause(snippet)
    ]
    clause2labels = get_pred_classification_labels(all_clauses)
    output_clauses = [
        (clause, str(position)) for position, clause in enumerate(all_clauses, start=1)
    ]
    return output_clauses, clause2labels