"""Data module for lexical-graph training using prebuilt .pt graphs only."""

import os
from typing import Dict, List, Optional

import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from mecari.featurizers.lexical import (
    LexicalNGramFeaturizer as LexFeaturizer,
    Morpheme as LexMorpheme,
)


class _PtGraphDataset(Dataset):
    """Prebuilt PyG graph tensors saved as .pt per sentence.

    Each file is expected to be a dict with keys:
    - 'graph': torch_geometric.data.Data
    - 'source_id': str (used for split)
    - optional: 'text'
    """

    def __init__(self, files: List[str]) -> None:
        self.files = files

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> Data:
        path = self.files[idx]
        obj = torch.load(path, map_location="cpu")
        if isinstance(obj, dict) and "graph" in obj:
            data = obj["graph"]
        else:
            data = obj
        if not isinstance(data, Data):
            raise RuntimeError(f"Invalid graph object in: {path}")
        # Remember the original file index so predictions can be traced back.
        data.data_index = idx
        return data
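

# Minimal sketch of writing a .pt file in the layout _PtGraphDataset expects.
# `_example_write_pt_graph` is a hypothetical helper (not used elsewhere in
# this module), and the edge_index/source_id values are placeholders.
def _example_write_pt_graph(path: str) -> None:
    graph = Data(
        edge_index=torch.tensor([[0, 1], [1, 2]], dtype=torch.long),
        num_nodes=3,
    )
    torch.save({"graph": graph, "source_id": "example-document-id", "text": "..."}, path)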


# Allowlist PyG classes so torch.load works under PyTorch >= 2.6, where
# weights_only=True became the default; older torch/PyG versions lack these
# hooks, so failures are ignored.
try:
    import torch.serialization
    from torch_geometric.data.data import DataEdgeAttr

    torch.serialization.add_safe_globals([DataEdgeAttr, Data])
except (ImportError, AttributeError):
    pass


class DataModule(pl.LightningDataModule):
    """Loads .pt graphs and builds lexical graph features for training."""

    def __init__(
        self,
        annotations_dir: str = "annotations",
        batch_size: int = 32,
        num_workers: int = 0,
        max_files: Optional[int] = None,
        use_bidirectional_edges: bool = True,
        annotations_override_dir: Optional[str] = None,
        silent: bool = False,
        lexical_feature_dim: int = 100000,
        lexical_max_features: int = 20,
    ) -> None:
        super().__init__()
        self.annotations_dir = annotations_dir
        self.annotations_override_dir = annotations_override_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.max_files = max_files
        self.silent = silent
        self.lexical_feature_dim = lexical_feature_dim
        self.lexical_max_features = int(lexical_max_features)
        self.use_bidirectional_edges = bool(use_bidirectional_edges)

        self.train_dataset = []
        self.val_dataset = []
        self.test_dataset = []

        self._lex_featurizer = LexFeaturizer(dim=int(self.lexical_feature_dim), add_bias=True)

        # Japanese POS tags mapped to integer ids (0 is reserved for unknown tags).
        self.pos_to_id = {
            "名詞": 1,
            "動詞": 2,
            "形容詞": 3,
            "副詞": 4,
            "助詞": 5,
            "助動詞": 6,
            "接続詞": 7,
            "連体詞": 8,
            "感動詞": 9,
            "形状詞": 10,
            "補助記号": 11,
            "接頭辞": 12,
            "接尾辞": 13,
            "特殊": 14,
        }
        self.id_to_pos = {v: k for k, v in self.pos_to_id.items()}

    def create_graph_from_morphemes_data(self, *args, **kwargs) -> Optional[Data]:
        """Create a lexical graph from morpheme data (or candidates)."""
        if "candidates" in kwargs:
            candidates = kwargs.pop("candidates")
            text = kwargs.get("text", "")
            morphemes_edges = self._build_graph_from_candidates(candidates, text)
            if not morphemes_edges:
                return None
            kwargs["morphemes"] = morphemes_edges["morphemes"]
            kwargs["edges"] = morphemes_edges["edges"]
        return self._create_lexical_graph(*args, **kwargs)

    def compute_lexical_features(self, morphemes: List[Dict], text: str) -> List[Dict]:
        """Add lexical_features to each morpheme using Mecari's lexical featurizer.

        Requires mecari.featurizers.lexical to be importable. Raises a clear error
        if the featurizer is unavailable (training/inference depend on it).
        """
        if not morphemes:
            return morphemes

        for m in morphemes:
            try:
                morph_obj = LexMorpheme(
                    surf=m.get("surface", ""),
                    lemma=m.get("base_form", ""),
                    pos=m.get("pos", "*"),
                    pos1=m.get("pos_detail1", "*"),
                    ctype=m.get("inflection_type", "*"),
                    cform=m.get("inflection_form", "*"),
                    reading=m.get("reading", "*"),
                )
                st = m.get("start_pos", 0)
                ed = m.get("end_pos", st + len(m.get("surface", "")))
                prev_char = text[st - 1] if st > 0 else None
                next_char = text[ed] if ed < len(text) else None
                feats = self._lex_featurizer.unigram_feats(morph_obj, prev_char, next_char)
                m["lexical_features"] = feats
            except Exception:
                # Leave this morpheme without lexical_features rather than
                # failing the whole sentence.
                pass
        return morphemes
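
    # Example morpheme dict (hypothetical values; the keys are exactly the ones
    # read above, plus the 'annotation' field consumed by _create_lexical_graph):
    #     {"surface": "走っ", "base_form": "走る", "pos": "動詞",
    #      "pos_detail1": "一般", "inflection_type": "五段-ラ行",
    #      "inflection_form": "連用形-促音便", "reading": "ハシッ",
    #      "start_pos": 0, "end_pos": 2, "annotation": "+"}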

    def _create_lexical_graph(
        self, morphemes: List[Dict], edges: List[Dict], text: str, for_training: bool = True
    ) -> Optional[Data]:
        """Build a graph using lexical features."""
        if not morphemes:
            return None

        all_indices = []
        all_values = []
        all_lengths = []
        annotations = []
        valid_mask = []

        for morpheme in morphemes:
            lexical_feats = morpheme.get("lexical_features", [])
            indices = []
            values = []
            for idx, val in lexical_feats:
                if 0 <= idx < self.lexical_feature_dim:
                    indices.append(idx)
                    values.append(val)
            all_lengths.append(len(indices))

            all_indices.append(indices)
            all_values.append(values)

            if for_training:
                # '+' = positive, '-' = negative; anything else (e.g. '?') is
                # masked out of the loss via valid_mask.
                annotation = morpheme.get("annotation", "?")
                if annotation == "+":
                    annotations.append(1)
                    valid_mask.append(True)
                elif annotation == "-":
                    annotations.append(0)
                    valid_mask.append(True)
                else:
                    annotations.append(0)
                    valid_mask.append(False)

        # Truncate or zero-pad every node's feature list to a fixed width so
        # the per-node tensors stack into a rectangular batch.
        FIXED_MAX_FEATURES = int(getattr(self, "lexical_max_features", 20))

        padded_indices = []
        padded_values = []
        for indices, values in zip(all_indices, all_values):
            if len(indices) > FIXED_MAX_FEATURES:
                padded_indices.append(indices[:FIXED_MAX_FEATURES])
                padded_values.append(values[:FIXED_MAX_FEATURES])
            else:
                pad_length = FIXED_MAX_FEATURES - len(indices)
                padded_indices.append(indices + [0] * pad_length)
                padded_values.append(values + [0.0] * pad_length)

        edge_index = self._build_edge_index(edges, len(morphemes))

        pos_ids = []
        for m in morphemes:
            pos = m.get("pos", "*")
            pos_ids.append(self.pos_to_id.get(pos, 0))

        graph_data = Data(
            lexical_indices=torch.tensor(padded_indices, dtype=torch.long),
            lexical_values=torch.tensor(padded_values, dtype=torch.float32),
            lexical_lengths=torch.tensor(all_lengths, dtype=torch.long),
            edge_index=edge_index,
            num_nodes=len(morphemes),
        )
        graph_data.pos_ids = torch.tensor(pos_ids, dtype=torch.long)
        if for_training:
            graph_data.y = torch.tensor(annotations, dtype=torch.float32)
            graph_data.valid_mask = torch.tensor(valid_mask, dtype=torch.bool)

        return graph_data
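
    # For a 3-morpheme sentence with lexical_max_features=20, the resulting
    # tensors are lexical_indices: (3, 20) long, lexical_values: (3, 20) float32,
    # lexical_lengths: (3,) long, and edge_index: (2, num_edges) long.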

    def _build_edge_index(self, edges: List[Dict], num_nodes: int) -> torch.Tensor:
        """Build a PyG edge_index tensor from edge dicts with 'source_idx' and
        'target_idx' keys; edges with out-of-range endpoints are dropped."""
        if not edges:
            return torch.tensor([[], []], dtype=torch.long)

        source_indices = []
        target_indices = []

        for edge in edges:
            source = edge.get("source_idx", 0)
            target = edge.get("target_idx", 0)

            if 0 <= source < num_nodes and 0 <= target < num_nodes:
                source_indices.append(source)
                target_indices.append(target)
                # Optionally add the reverse edge so message passing runs both ways.
                if self.use_bidirectional_edges:
                    source_indices.append(target)
                    target_indices.append(source)

        if not source_indices:
            return torch.tensor([[], []], dtype=torch.long)

        return torch.tensor([source_indices, target_indices], dtype=torch.long)

    def _load_kwdlc_ids(self, ids_file: str) -> set:
        """Load a KWDLC ID list (one ID per line); returns an empty set if missing."""
        ids = set()
        if ids_file and os.path.exists(ids_file):
            with open(ids_file, "r") as f:
                for line in f:
                    ids.add(line.strip())
        return ids

    def load_annotation_data(self, max_files: Optional[int] = None) -> List[Dict]:
        """Detect and list available .pt annotation graph files."""
        if os.path.isdir(self.annotations_dir):
            pt_files = [
                os.path.join(self.annotations_dir, fn)
                for fn in sorted(os.listdir(self.annotations_dir))
                if fn.endswith(".pt")
            ]
            if pt_files:
                if max_files is not None:
                    pt_files = pt_files[:max_files]
                return [{"_mode": "pt", "_pt_files": pt_files}]
        raise FileNotFoundError(f"No annotation graphs found under: {self.annotations_dir}")
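
    # Expected layout: one prebuilt graph per sentence, e.g.
    #     annotations/
    #         sent-000001.pt
    #         sent-000002.pt
    # (file names here are free-form placeholders; only the .pt suffix matters).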

    def setup(self, stage: Optional[str] = None) -> None:
        """Build train/val/test datasets from discovered .pt files."""
        annotation_data = self.load_annotation_data(max_files=self.max_files)

        if not annotation_data:
            self.train_dataset = []
            self.val_dataset = []
            self.test_dataset = []
            return

        # KWDLC id lists decide the dev/test split; everything else is train.
        dev_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "dev.id"))
        test_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "test.id"))

        mode = annotation_data[0].get("_mode")
        if mode == "pt":
            files: List[str] = annotation_data[0]["_pt_files"]
            train_files: List[str] = []
            val_files: List[str] = []
            test_files: List[str] = []

            for fp in files:
                # Peek at each file's source_id to route it to the right split.
                sid = None
                try:
                    obj = torch.load(fp, map_location="cpu")
                    if isinstance(obj, dict):
                        sid = obj.get("source_id")
                except Exception:
                    pass
                if sid and (dev_ids or test_ids):
                    if sid in test_ids:
                        test_files.append(fp)
                    elif sid in dev_ids:
                        val_files.append(fp)
                    else:
                        train_files.append(fp)
                else:
                    train_files.append(fp)

            self.train_dataset = _PtGraphDataset(train_files)
            self.val_dataset = _PtGraphDataset(val_files)
            self.test_dataset = _PtGraphDataset(test_files)

            if len(self.val_dataset) == 0 or len(self.test_dataset) == 0:
                raise RuntimeError(
                    "KWDLC dev/test split produced empty val/test datasets. "
                    "Ensure KWDLC id files exist and source_id is set in .pt files."
                )
        else:
            raise RuntimeError("Unsupported annotation mode; expected pt")

        if not self.silent:
            print(
                f"Data split: train={len(self.train_dataset)}, "
                f"val={len(self.val_dataset)}, test={len(self.test_dataset)}"
            )

    def _create_dataloader(self, dataset: Dataset, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Create a DataLoader with optional workers/prefetching."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=False,
            # Worker options only apply when multiprocessing is enabled.
            persistent_workers=self.num_workers > 0,
            prefetch_factor=2 if self.num_workers > 0 else None,
        )

    def train_dataloader(self) -> DataLoader:
        """Return train DataLoader."""
        return self._create_dataloader(self.train_dataset, self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return val DataLoader."""
        return self._create_dataloader(self.val_dataset, self.batch_size, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Return test DataLoader."""
        return self._create_dataloader(self.test_dataset, self.batch_size, shuffle=False)
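

# Usage sketch (assumes prebuilt graphs under annotations/, KWDLC id files for
# the split, and a compatible LightningModule named `model`):
#
#     dm = DataModule(annotations_dir="annotations", batch_size=32)
#     dm.setup()
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(model, datamodule=dm)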