#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Data module for lexical-graph training using prebuilt .pt graphs only."""
import os
from typing import Dict, List, Optional

import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# Required import for lexical feature computation
from mecari.featurizers.lexical import (
    LexicalNGramFeaturizer as LexFeaturizer,
    Morpheme as LexMorpheme,
)
# Prebuilt .pt graph dataset
class _PtGraphDataset(Dataset):
"""Prebuilt PyG graph tensors saved as .pt per sentence.
Each file is expected to be a dict with keys:
- 'graph': torch_geometric.data.Data
- 'source_id': str (used for split)
- optional: 'text'
"""
def __init__(self, files: List[str]) -> None:
self.files = files
def __len__(self) -> int:
return len(self.files)
def __getitem__(self, idx: int) -> Data:
path = self.files[idx]
obj = torch.load(path, map_location="cpu")
if isinstance(obj, dict) and "graph" in obj:
data = obj["graph"]
else:
data = obj
if not isinstance(data, Data):
raise RuntimeError(f"Invalid graph object in: {path}")
data.data_index = idx
return data
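# A compatible .pt file could be written roughly like this (illustrative sketch only;
# the file name and source_id value are placeholders, and only the keys documented in
# _PtGraphDataset are assumed):
#
#     graph = Data(edge_index=..., num_nodes=...)  # e.g. as built by _create_lexical_graph
#     torch.save(
#         {"graph": graph, "source_id": "some-kwdlc-sentence-id", "text": "..."},
#         "annotations/example_sentence.pt",
#     )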
# Register safe globals so torch.load can deserialize PyG objects under PyTorch 2.6+'s weights_only default
try:
import torch.serialization
from torch_geometric.data.data import DataEdgeAttr
torch.serialization.add_safe_globals([DataEdgeAttr, Data])
except (ImportError, AttributeError):
pass
class DataModule(pl.LightningDataModule):
"""Loads .pt graphs and builds lexical graph features for training."""
def __init__(
self,
annotations_dir: str = "annotations",
batch_size: int = 32,
num_workers: int = 0,
max_files: Optional[int] = None,
use_bidirectional_edges: bool = True,
annotations_override_dir: Optional[str] = None,
silent: bool = False,
lexical_feature_dim: int = 100000,
lexical_max_features: int = 20,
) -> None:
super().__init__()
self.annotations_dir = annotations_dir
self.annotations_override_dir = annotations_override_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.max_files = max_files
self.silent = silent
self.lexical_feature_dim = lexical_feature_dim
self.lexical_max_features = int(lexical_max_features)
self.use_bidirectional_edges = bool(use_bidirectional_edges)
# Initialized in setup()
self.train_dataset = []
self.val_dataset = []
self.test_dataset = []
# Eagerly initialize lexical featurizer (small and picklable)
self._lex_featurizer = LexFeaturizer(dim=int(self.lexical_feature_dim), add_bias=True)
# POS mapping for evaluation breakdown
self.pos_to_id = {
"名詞": 1,
"動詞": 2,
"形容詞": 3,
"副詞": 4,
"助詞": 5,
"助動詞": 6,
"接続詞": 7,
"連体詞": 8,
"感動詞": 9,
"形状詞": 10,
"補助記号": 11,
"接頭辞": 12,
"接尾辞": 13,
"特殊": 14,
}
self.id_to_pos = {v: k for k, v in self.pos_to_id.items()}
def create_graph_from_morphemes_data(self, *args, **kwargs) -> Optional[Data]:
"""Create a lexical graph from morpheme data (or candidates)."""
if "candidates" in kwargs:
candidates = kwargs.pop("candidates")
text = kwargs.get("text", "")
morphemes_edges = self._build_graph_from_candidates(candidates, text)
if not morphemes_edges:
return None
kwargs["morphemes"] = morphemes_edges["morphemes"]
kwargs["edges"] = morphemes_edges["edges"]
return self._create_lexical_graph(*args, **kwargs)
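    # Illustrative call (sketch): passing morphemes/edges forwards straight to
    # _create_lexical_graph; passing candidates=... additionally relies on a
    # _build_graph_from_candidates helper that is not defined in this file.
    #
    #     graph = dm.create_graph_from_morphemes_data(
    #         morphemes=morphemes, edges=edges, text=text, for_training=True,
    #     )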
# --- Lexical features helper (for preprocessing) ---
def compute_lexical_features(self, morphemes: List[Dict], text: str) -> List[Dict]:
"""Add lexical_features to each morpheme using Mecari's lexical featurizer.
Requires mecari.featurizers.lexical to be importable. Raises a clear error
if the featurizer is unavailable (training/inference depend on it).
"""
if not morphemes:
return morphemes
for m in morphemes:
try:
morph_obj = LexMorpheme(
surf=m.get("surface", ""),
lemma=m.get("base_form", ""),
pos=m.get("pos", "*"),
pos1=m.get("pos_detail1", "*"),
ctype=m.get("inflection_type", "*"),
cform=m.get("inflection_form", "*"),
reading=m.get("reading", "*"),
)
st = m.get("start_pos", 0)
ed = m.get("end_pos", st + len(m.get("surface", "")))
prev_char = text[st - 1] if st > 0 else None
next_char = text[ed] if ed < len(text) else None
feats = self._lex_featurizer.unigram_feats(morph_obj, prev_char, next_char)
m["lexical_features"] = feats
except Exception:
# on any failure, leave unchanged
pass
return morphemes
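    # Expected per-morpheme dict shape (keys as read above; the concrete values are
    # only illustrative):
    #
    #     m = {"surface": "猫", "base_form": "猫", "pos": "名詞", "pos_detail1": "普通名詞",
    #          "inflection_type": "*", "inflection_form": "*", "reading": "ネコ",
    #          "start_pos": 0, "end_pos": 1}
    #     compute_lexical_features([m], "猫が鳴く")
    #     # m["lexical_features"] is now a list of sparse (feature_index, value) pairs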
def _create_lexical_graph(
self, morphemes: List[Dict], edges: List[Dict], text: str, for_training: bool = True
) -> Optional[Data]:
"""Build a graph using lexical features."""
if not morphemes:
return None
# Sparse lexical features per node
all_indices = []
all_values = []
all_lengths = []
annotations = []
valid_mask = []
max_features = 0
for morpheme in morphemes:
lexical_feats = morpheme.get("lexical_features", [])
indices = []
values = []
for idx, val in lexical_feats:
if 0 <= idx < self.lexical_feature_dim:
indices.append(idx)
values.append(val)
all_lengths.append(len(indices))
max_features = max(max_features, len(indices))
all_indices.append(indices)
all_values.append(values)
if for_training:
annotation = morpheme.get("annotation", "?")
if annotation == "+":
annotations.append(1)
valid_mask.append(True)
elif annotation == "-":
annotations.append(0)
valid_mask.append(True)
else:
annotations.append(0)
valid_mask.append(False)
# Fixed-size padding/truncation for batching
FIXED_MAX_FEATURES = int(getattr(self, "lexical_max_features", 20))
padded_indices = []
padded_values = []
for indices, values in zip(all_indices, all_values):
if len(indices) > FIXED_MAX_FEATURES:
padded_indices.append(indices[:FIXED_MAX_FEATURES])
padded_values.append(values[:FIXED_MAX_FEATURES])
else:
pad_length = FIXED_MAX_FEATURES - len(indices)
padded_indices.append(indices + [0] * pad_length)
padded_values.append(values + [0.0] * pad_length)
edge_index = self._build_edge_index(edges, len(morphemes))
# POS ids per node (for evaluation breakdown)
pos_ids = []
for m in morphemes:
pos = m.get("pos", "*")
pos_ids.append(self.pos_to_id.get(pos, 0))
graph_data = Data(
lexical_indices=torch.tensor(padded_indices, dtype=torch.long),
lexical_values=torch.tensor(padded_values, dtype=torch.float32),
lexical_lengths=torch.tensor(all_lengths, dtype=torch.long),
edge_index=edge_index,
num_nodes=len(morphemes),
)
graph_data.pos_ids = torch.tensor(pos_ids, dtype=torch.long)
if for_training:
graph_data.y = torch.tensor(annotations, dtype=torch.float32)
graph_data.valid_mask = torch.tensor(valid_mask, dtype=torch.bool)
return graph_data
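    # Shapes of the resulting Data object, with N nodes and E (possibly doubled) edges:
    #   lexical_indices  [N, lexical_max_features]  padded/truncated feature indices
    #   lexical_values   [N, lexical_max_features]  matching feature values
    #   lexical_lengths  [N]                        original (pre-truncation) counts
    #   edge_index       [2, E]
    #   pos_ids          [N]
    #   y, valid_mask    [N]  (training only)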
def _build_edge_index(self, edges: List[Dict], num_nodes: int) -> torch.Tensor:
"""Build a PyG edge_index tensor from edge dicts."""
if not edges:
return torch.tensor([[], []], dtype=torch.long)
source_indices = []
target_indices = []
for edge in edges:
source = edge.get("source_idx", 0)
target = edge.get("target_idx", 0)
if 0 <= source < num_nodes and 0 <= target < num_nodes:
source_indices.append(source)
target_indices.append(target)
if self.use_bidirectional_edges:
source_indices.append(target)
target_indices.append(source)
if not source_indices:
return torch.tensor([[], []], dtype=torch.long)
return torch.tensor([source_indices, target_indices], dtype=torch.long)
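    # Worked example: edges=[{"source_idx": 0, "target_idx": 1}] with
    # use_bidirectional_edges=True yields tensor([[0, 1], [1, 0]]), i.e. the reverse
    # edge is appended immediately after the forward edge.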
def _load_kwdlc_ids(self, ids_file: str) -> set:
"""Load KWDLC ID list (one ID per line)."""
ids = set()
if ids_file and os.path.exists(ids_file):
with open(ids_file, "r") as f:
for line in f:
ids.add(line.strip())
return ids
def load_annotation_data(self, max_files: Optional[int] = None) -> List[Dict]:
"""Detect and list available .pt annotation graph files."""
if os.path.isdir(self.annotations_dir):
pt_files = [
os.path.join(self.annotations_dir, fn)
for fn in sorted(os.listdir(self.annotations_dir))
if fn.endswith(".pt")
]
if pt_files:
if max_files is not None:
pt_files = pt_files[:max_files]
return [{"_mode": "pt", "_pt_files": pt_files}]
raise FileNotFoundError(f"No annotation graphs found under: {self.annotations_dir}")
def setup(self, stage: Optional[str] = None) -> None:
"""Build train/val/test datasets from discovered .pt files."""
annotation_data = self.load_annotation_data(max_files=self.max_files)
if not annotation_data:
self.train_dataset = []
self.val_dataset = []
self.test_dataset = []
return
mode = annotation_data[0].get("_mode")
if mode == "pt":
files: List[str] = annotation_data[0]["_pt_files"]
train_files: List[str] = []
val_files: List[str] = []
test_files: List[str] = []
# Use KWDLC split ids (mandatory)
dev_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "dev.id"))
test_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "test.id"))
for fp in files:
sid = None
try:
obj = torch.load(fp, map_location="cpu")
if isinstance(obj, dict):
sid = obj.get("source_id")
except Exception:
pass
if sid and (dev_ids or test_ids):
if sid in test_ids:
test_files.append(fp)
elif sid in dev_ids:
val_files.append(fp)
else:
train_files.append(fp)
else:
train_files.append(fp)
# Build datasets strictly based on KWDLC dev/test ids
self.train_dataset = _PtGraphDataset(train_files)
self.val_dataset = _PtGraphDataset(val_files)
self.test_dataset = _PtGraphDataset(test_files)
if len(self.val_dataset) == 0 or len(self.test_dataset) == 0:
raise RuntimeError(
"KWDLC dev/test split produced empty val/test datasets. Ensure KWDLC id files exist and source_id is set in .pt files."
)
else:
raise RuntimeError("Unsupported annotation mode; expected pt")
        if not self.silent:
            print(
                f"Data split: train={len(self.train_dataset)}, val={len(self.val_dataset)}, test={len(self.test_dataset)}"
            )
    def _create_dataloader(self, dataset: Dataset, batch_size: int, shuffle: bool = False) -> DataLoader:
"""Create a DataLoader with optional workers/prefetching."""
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=self.num_workers,
pin_memory=False,
            persistent_workers=self.num_workers > 0,
prefetch_factor=2 if self.num_workers > 0 else None,
)
def train_dataloader(self) -> DataLoader:
"""Return train DataLoader."""
return self._create_dataloader(self.train_dataset, self.batch_size, shuffle=True)
def val_dataloader(self) -> DataLoader:
"""Return val DataLoader."""
return self._create_dataloader(self.val_dataset, self.batch_size, shuffle=False)
def test_dataloader(self) -> DataLoader:
"""Return test DataLoader."""
return self._create_dataloader(self.test_dataset, self.batch_size, shuffle=False)
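
if __name__ == "__main__":
    # Minimal smoke-test sketch: assumes an "annotations" directory of prebuilt .pt
    # graphs and the KWDLC id files referenced in setup() are present in the working
    # directory; adjust paths as needed.
    dm = DataModule(annotations_dir="annotations", batch_size=4)
    dm.setup()
    batch = next(iter(dm.train_dataloader()))
    print(batch)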