|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Emit an early status line (flushed so it is visible before the slow
# torch/model imports below finish).
print("Loading model...", flush=True)
|
|
|
|
|
import os |
|
|
import random |
|
|
from typing import Any, Dict, Optional, Tuple |
|
|
|
|
|
|
|
|
# Disable Weights & Biases logging for inference runs.
os.environ["WANDB_MODE"] = "disabled"
|
|
|
|
|
from importlib import import_module |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import yaml |
|
|
|
|
|
from mecari.analyzers.mecab import MeCabAnalyzer |
|
|
from mecari.data.data_module import DataModule |
|
|
from mecari.utils.morph_utils import build_adjacent_edges, dedup_morphemes, normalize_mecab_candidates |
|
|
|
|
|
|
|
|
def set_seed(seed: int = 42) -> None:
    """Seed every random-number source so inference is reproducible.

    Seeds Python's ``random`` module, NumPy, and Torch (CPU plus all CUDA
    devices), and switches cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Random seed value.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Trade a little speed for run-to-run determinism in cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
# Seed all RNGs at import time so module-level inference is reproducible.
set_seed(42)
|
|
|
|
|
|
|
|
def _find_best_checkpoint(checkpoints_dir: str, prefer_metric: str = "val_error") -> Tuple[Optional[str], float]: |
|
|
"""Find the best checkpoint file in a directory. |
|
|
|
|
|
Args: |
|
|
checkpoints_dir: Path to the checkpoints directory. |
|
|
prefer_metric: Preferred metric ("val_error" or "val_loss"). |
|
|
|
|
|
Returns: |
|
|
Tuple of (best checkpoint filename, score). |
|
|
""" |
|
|
checkpoint_files = [f for f in os.listdir(checkpoints_dir) if f.endswith(".ckpt")] |
|
|
if not checkpoint_files: |
|
|
return None, float("inf") |
|
|
|
|
|
best_checkpoint = None |
|
|
best_score = float("inf") |
|
|
|
|
|
|
|
|
for ckpt_file in checkpoint_files: |
|
|
if prefer_metric == "val_loss" and ("val_loss=" in ckpt_file or "val_loss_epoch=" in ckpt_file): |
|
|
try: |
|
|
if "val_loss_epoch=" in ckpt_file: |
|
|
score_str = ckpt_file.split("val_loss_epoch=")[-1].split(".ckpt")[0] |
|
|
else: |
|
|
score_str = ckpt_file.split("val_loss=")[-1].split(".ckpt")[0] |
|
|
score = float(score_str) |
|
|
if score < best_score: |
|
|
best_score = score |
|
|
best_checkpoint = ckpt_file |
|
|
except (ValueError, IndexError): |
|
|
pass |
|
|
elif prefer_metric == "val_error" and ("val_error=" in ckpt_file or "val_error_epoch=" in ckpt_file): |
|
|
try: |
|
|
if "val_error_epoch=" in ckpt_file: |
|
|
score_str = ckpt_file.split("val_error_epoch=")[-1].split(".ckpt")[0] |
|
|
else: |
|
|
score_str = ckpt_file.split("val_error=")[-1].split(".ckpt")[0] |
|
|
score = float(score_str) |
|
|
if score < best_score: |
|
|
best_score = score |
|
|
best_checkpoint = ckpt_file |
|
|
except (ValueError, IndexError): |
|
|
pass |
|
|
|
|
|
|
|
|
if not best_checkpoint: |
|
|
other_metric = "val_loss" if prefer_metric == "val_error" else "val_error" |
|
|
for ckpt_file in checkpoint_files: |
|
|
if other_metric == "val_loss" and "val_loss=" in ckpt_file: |
|
|
try: |
|
|
score_str = ckpt_file.split("val_loss=")[1].split("-loss.ckpt")[0] |
|
|
score = float(score_str) |
|
|
if score < best_score: |
|
|
best_score = score |
|
|
best_checkpoint = ckpt_file |
|
|
except (ValueError, IndexError): |
|
|
pass |
|
|
elif other_metric == "val_error" and "val_error=" in ckpt_file: |
|
|
try: |
|
|
score_str = ckpt_file.split("val_error=")[1].split(".ckpt")[0] |
|
|
score = float(score_str) |
|
|
if score < best_score: |
|
|
best_score = score |
|
|
best_checkpoint = ckpt_file |
|
|
except (ValueError, IndexError): |
|
|
pass |
|
|
|
|
|
|
|
|
if not best_checkpoint: |
|
|
for ckpt_file in sorted(checkpoint_files): |
|
|
if ckpt_file == "last.ckpt": |
|
|
continue |
|
|
try: |
|
|
stem = ckpt_file[:-5] if ckpt_file.endswith(".ckpt") else ckpt_file |
|
|
|
|
|
last_tok = stem.split("-")[-1] |
|
|
score = float(last_tok) |
|
|
if score < best_score: |
|
|
best_score = score |
|
|
best_checkpoint = ckpt_file |
|
|
except Exception: |
|
|
continue |
|
|
|
|
|
if not best_checkpoint: |
|
|
if "last.ckpt" in checkpoint_files: |
|
|
best_checkpoint = "last.ckpt" |
|
|
else: |
|
|
best_checkpoint = sorted(checkpoint_files)[0] |
|
|
|
|
|
return best_checkpoint, best_score |
|
|
|
|
|
|
|
|
def _load_model_by_type(model_type: str, checkpoint_path: str) -> Any: |
|
|
"""Load the appropriate model class based on type. |
|
|
|
|
|
Args: |
|
|
model_type: Model type ("gat" or "gatv2"). |
|
|
checkpoint_path: Path to the checkpoint file. |
|
|
|
|
|
Returns: |
|
|
Loaded model instance. |
|
|
""" |
|
|
if model_type == "gatv2": |
|
|
cls = getattr(import_module("mecari.models.gatv2"), "MecariGATv2") |
|
|
model = cls.load_from_checkpoint(checkpoint_path, strict=False, map_location="cpu") |
|
|
|
|
|
model.eval() |
|
|
model.cpu() |
|
|
return model |
|
|
|
|
|
|
|
|
def _instantiate_model_from_config(config: Dict[str, Any]): |
|
|
"""Instantiate a model using config fields (no checkpoint loading).""" |
|
|
model_cfg = config.get("model", {}) |
|
|
training_cfg = config.get("training", {}) |
|
|
features_cfg = config.get("features", {}) |
|
|
|
|
|
if model_cfg.get("type") != "gatv2": |
|
|
raise ValueError(f"Unsupported model type: {model_cfg.get('type')}") |
|
|
|
|
|
MecariGATv2 = getattr(import_module("mecari.models.gatv2"), "MecariGATv2") |
|
|
model = MecariGATv2( |
|
|
hidden_dim=model_cfg.get("hidden_dim", 64), |
|
|
num_classes=model_cfg.get("num_classes", 1), |
|
|
learning_rate=training_cfg.get("learning_rate", 1e-3), |
|
|
lexical_feature_dim=features_cfg.get("lexical_feature_dim", 100000), |
|
|
num_heads=model_cfg.get("num_heads", 4), |
|
|
share_weights=model_cfg.get("share_weights", False), |
|
|
dropout=model_cfg.get("dropout", 0.1), |
|
|
attn_dropout=model_cfg.get("attn_dropout", model_cfg.get("attention_dropout", 0.1)), |
|
|
add_self_loops_flag=model_cfg.get("add_self_loops", True), |
|
|
edge_dropout=model_cfg.get("edge_dropout", 0.0), |
|
|
norm=model_cfg.get("norm", "layer"), |
|
|
) |
|
|
return model |
|
|
|
|
|
|
|
|
def _load_model_from_state(config_path: str, state_path: str):
    """Restore a model from a plain state_dict file plus its config.yaml."""
    with open(config_path, "r", encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    model = _instantiate_model_from_config(config)
    state = torch.load(state_path, map_location="cpu")

    # Lightning-style checkpoints nest parameters under "state_dict" with a
    # "model." prefix on every key; unwrap that container when detected.
    if isinstance(state, dict) and "state_dict" in state:
        inner = state["state_dict"]
        if all(key.startswith("model.") for key in inner.keys()):
            state = inner

    # Strip any remaining "model." prefixes so keys match the bare module.
    prefix = "model."
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }

    model.load_state_dict(cleaned, strict=False)
    model.eval()
    model.cpu()
    return model
|
|
|
|
|
|
|
|
def load_model(
    experiment_name: Optional[str] = None, model_type: Optional[str] = None, prefer_metric: str = "val_error"
) -> Optional[Tuple[Any, Dict[str, Any]]]:
    """Load a trained model and its experiment info.

    Default behavior: load the single model under sample_model/.
    If an experiment name is provided, load the best checkpoint from
    experiments/<experiment_name>/ instead.

    Args:
        experiment_name: Experiment directory name under experiments/; when
            falsy, the bundled sample_model/ is used.
        model_type: Optional model-type filter used when scanning experiments.
        prefer_metric: Metric used to rank checkpoints ("val_error" or
            "val_loss").

    Returns:
        Tuple of (model, experiment_info dict) on success, or None on any
        failure (errors are printed, not raised).
    """
    if not experiment_name:
        root = "sample_model"
        if os.path.exists(root):
            fixed_config = os.path.join(root, "config.yaml")
            state_path = os.path.join(root, "model.pt")
            if os.path.exists(fixed_config) and os.path.exists(state_path):
                try:
                    with open(fixed_config, "r", encoding="utf-8") as f:
                        config = yaml.safe_load(f)
                    model = _load_model_from_state(fixed_config, state_path)
                    experiment_info = {
                        "name": os.path.basename(root),
                        "path": root,
                        "best_metric": None,
                        "best_score": None,
                        "model_type": config.get("model", {}).get("type", "unknown"),
                        "best_model_path": state_path,
                        "config": config,
                    }
                    return model, experiment_info
                except Exception as e:
                    print(f"Failed to load sample model: {e}")
                    return None
            print("sample_model/model.pt or config.yaml not found")
            return None
        else:
            print("sample_model directory not found")
            return None

    # Root directory of all training runs. FIX: this name was previously
    # referenced in the scan branch below without ever being defined,
    # which would have raised a NameError.
    experiments_dir = "experiments"

    if experiment_name:
        exp_path = os.path.join("experiments", experiment_name)
        config_path = os.path.join(exp_path, "config.yaml")
        checkpoints_dir = os.path.join(exp_path, "checkpoints")

        if not os.path.exists(config_path) or not os.path.exists(checkpoints_dir):
            print(f"Experiment not found: {experiment_name}")
            return None

        try:
            with open(config_path, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f)

            model_type_from_config = config.get("model", {}).get("type", "unknown")
            best_checkpoint, best_score = _find_best_checkpoint(checkpoints_dir, prefer_metric)

            if not best_checkpoint:
                print("No checkpoint found")
                return None

            metric_name = "val_loss" if prefer_metric == "val_loss" else "val_error"

            experiment_info = {
                "name": experiment_name,
                "path": exp_path,
                "val_error": best_score if prefer_metric == "val_error" else None,
                "val_loss": best_score if prefer_metric == "val_loss" else None,
                "best_metric": metric_name,
                "best_score": best_score,
                "model_type": model_type_from_config,
                "best_model_path": os.path.join(checkpoints_dir, best_checkpoint),
                "config": config,
            }
        except Exception as e:
            print(f"Failed to read experiment info: {e}")
            return None

    else:
        # NOTE(review): this branch is currently unreachable because the
        # sample_model branch above returns for every falsy experiment_name;
        # it is kept so scanning can be re-enabled without rework.
        if not os.path.exists(experiments_dir):
            print("Experiments directory does not exist")
            return None

        experiments = []
        for exp_dir in os.listdir(experiments_dir):
            exp_path = os.path.join(experiments_dir, exp_dir)
            config_path = os.path.join(exp_path, "config.yaml")
            checkpoints_dir = os.path.join(exp_path, "checkpoints")

            if not os.path.exists(config_path) or not os.path.exists(checkpoints_dir):
                continue

            try:
                with open(config_path, "r", encoding="utf-8") as f:
                    config = yaml.safe_load(f)

                exp_model_type = config.get("model", {}).get("type", "unknown")

                # Honor the optional model-type filter (case-insensitive).
                if model_type and exp_model_type.lower() != model_type.lower():
                    continue

                best_checkpoint, best_score = _find_best_checkpoint(checkpoints_dir, prefer_metric)
                if best_checkpoint:
                    metric_name = "val_loss" if prefer_metric == "val_loss" else "val_error"
                    experiments.append(
                        {
                            "name": exp_dir,
                            "path": exp_path,
                            "val_error": best_score if prefer_metric == "val_error" else None,
                            "val_loss": best_score if prefer_metric == "val_loss" else None,
                            "best_metric": metric_name,
                            "best_score": best_score,
                            "model_type": exp_model_type,
                            "best_model_path": os.path.join(checkpoints_dir, best_checkpoint),
                            "config": config,
                        }
                    )
            except Exception:
                # Skip malformed experiment directories; best-effort scan.
                continue

        if not experiments:
            print("No available experiments found")
            return None

        # Lowest score wins (both supported metrics are "lower is better").
        experiment_info = min(experiments, key=lambda x: x["best_score"])

    print(f"Loading model: {experiment_info['best_model_path']}")
    print(f"Experiment: {experiment_info['name']}")

    try:
        model = _load_model_by_type(experiment_info["model_type"], experiment_info["best_model_path"])
        return model, experiment_info
    except Exception as e:
        print(f"Model loading error: {e}")
        return None
|
|
|
|
|
|
|
|
def viterbi_decode_from_morphemes(logits: torch.Tensor, morphemes: list, edges: list, silent: bool = False) -> list:
    """Edge-based Viterbi decoding over the morpheme lattice.

    Finds the highest-scoring left-to-right path through the candidate
    morphemes, where a path's score is the sum of the per-morpheme "logit"
    values stored on the morpheme dicts themselves (the ``logits`` tensor is
    used only for the length sanity check and debug printing).

    Args:
        logits: Logits per morpheme (1-D, aligned with ``morphemes``).
        morphemes: List of morpheme records; each is expected to carry
            "surface", "pos", "start_pos", "end_pos", and "logit" keys.
        edges: Edge list among morpheme indices; each edge is a dict with
            "source_idx" and "target_idx".
        silent: If True, suppress debug prints (cycle/iteration warnings are
            still printed).

    Returns:
        Indices of morphemes on the optimal path, in left-to-right order.
    """
    # Sanity check: logits and morphemes must be aligned 1:1. On mismatch,
    # degrade gracefully by returning a trivial identity path.
    if len(logits) != len(morphemes):
        if not silent:
            print(f"Warning: #logits ({len(logits)}) != #morphemes ({len(morphemes)})")
        return list(range(min(len(logits), len(morphemes))))

    if not silent:
        print("\n=== Viterbi Decode ===")
        print(f"#Morphemes: {len(morphemes)}")
        print(f"Using edge info: {len(edges)} edges")

        print("\nNode logits:")
        for idx, (morph, logit) in enumerate(zip(morphemes, logits)):
            print(
                f" [{idx:3d}] {morph['surface']:10s} ({morph['start_pos']:2d}-{morph['end_pos']:2d}) {morph['pos']:10s} logit={logit:.3f}"
            )

    # Build a forward adjacency list, keeping only edges that move forward
    # in the text (source must end at or before where the target starts).
    n = len(morphemes)
    adj_list = [[] for _ in range(n)]
    for edge in edges:
        source_idx = edge["source_idx"]
        target_idx = edge["target_idx"]
        if 0 <= source_idx < n and 0 <= target_idx < n:
            source_end = morphemes[source_idx].get("end_pos", 0)
            target_start = morphemes[target_idx].get("start_pos", 0)
            if source_end <= target_start:
                adj_list[source_idx].append(target_idx)

    # Japanese POS tag -> Universal Dependencies tag, used only for the
    # human-readable debug output below.
    pos_to_ud = {
        "名詞": "NOUN",
        "動詞": "VERB",
        "形容詞": "ADJ",
        "副詞": "ADV",
        "助詞": "ADP",
        "助動詞": "AUX",
        "接続詞": "CCONJ",
        "連体詞": "DET",
        "感動詞": "INTJ",
        "代名詞": "PRON",
        "形状詞": "ADJ",
        "補助記号": "PUNCT",
        "接頭辞": "PREFIX",
        "接尾辞": "SUFFIX",
    }

    if not silent:
        print("\nMorpheme details:")
        for i, morpheme in enumerate(morphemes):
            start_pos = morpheme.get("start_pos", 0)
            end_pos = morpheme.get("end_pos", 0)
            surface = morpheme.get("surface", "")
            logit = morpheme.get("logit", 0.0)
            pos = morpheme.get("pos", "")
            pos_main = pos.split(",")[0] if "," in pos else pos
            ud_pos = pos_to_ud.get(pos_main, "X")
            print(f" {i}: {surface} ({start_pos}-{end_pos}) {pos_main}({ud_pos}) logit={logit:.3f}")

    # dp[i]: best cumulative score of any path ending at morpheme i.
    # parent[i]: predecessor of i on that best path (-1 = path start).
    dp = [-float("inf")] * n
    parent = [-1] * n

    # Start nodes: all morphemes beginning at the earliest character position.
    start_nodes = []
    min_start_pos = min(m.get("start_pos", 0) for m in morphemes)
    for i, m in enumerate(morphemes):
        if m.get("start_pos", 0) == min_start_pos:
            start_nodes.append(i)

    for i in start_nodes:
        dp[i] = morphemes[i].get("logit", 0.0)

    # Relax nodes in order of (start_pos, end_pos) so every predecessor is
    # finalized before its successors are scored (the forward-edge filter
    # above guarantees this order is topological).
    node_positions = [(i, morphemes[i].get("start_pos", 0), morphemes[i].get("end_pos", 0)) for i in range(n)]
    node_positions.sort(key=lambda x: (x[1], x[2]))

    for node_idx, _, _ in node_positions:
        if dp[node_idx] == -float("inf"):
            # Unreachable node; nothing to propagate.
            continue

        for next_idx in adj_list[node_idx]:
            new_score = dp[node_idx] + morphemes[next_idx].get("logit", 0.0)
            if new_score > dp[next_idx]:
                dp[next_idx] = new_score
                parent[next_idx] = node_idx

    # End nodes: all morphemes ending at the last character position.
    end_nodes = []
    max_end_pos = max(m.get("end_pos", 0) for m in morphemes)
    for i, m in enumerate(morphemes):
        if m.get("end_pos", 0) == max_end_pos:
            end_nodes.append(i)

    # Pick the end node with the best cumulative score.
    best_end_idx = -1
    best_score = -float("inf")
    for i in end_nodes:
        if dp[i] > best_score:
            best_score = dp[i]
            best_end_idx = i

    # Backtrack from the best end node. The visited set and iteration cap
    # defend against malformed parent chains (cycles should be impossible
    # with forward-only edges, but fail safe rather than loop forever).
    path = []
    current = best_end_idx
    max_iterations = n * 2
    iteration_count = 0
    visited = set()

    while current != -1 and iteration_count < max_iterations:
        if current in visited:
            print(f"Warning: Detected cycle during backtracking (node {current})")
            break
        visited.add(current)
        path.append(current)
        current = parent[current]
        iteration_count += 1

    if iteration_count >= max_iterations:
        print(f"Warning: Backtracking reached max iterations ({max_iterations})")

    # Backtracking produced end-to-start order; restore reading order.
    path.reverse()

    if path:
        total_score = sum(morphemes[idx].get("logit", 0.0) for idx in path)
        if not silent:
            print(f"\nOptimal path (total score: {total_score:.3f}):")
            for idx in path:
                morpheme = morphemes[idx]
                logit = morpheme.get("logit", 0.0)
                print(f" {morpheme['surface']} (logit: {logit:.3f})")

    return path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Lazily-created MeCab analyzer, shared across predict calls.
_analyzer = None

# Cache of DataModule instances keyed by annotations directory, so repeated
# predictions reuse the same (expensive-to-build) module.
_data_module_cache = {}
|
|
|
|
|
|
|
|
def predict_morphemes_from_text(text: str, model: Any = None, experiment_info: Optional[Dict[str, Any]] = None, silent: bool = False):
    """Predict morpheme boundaries from text.

    Steps:
        1. Analyze with MeCab to get candidates.
        2. Build nodes/edges from morphemes and connections.
        3. Run the model to get per-node scores.
        4. Run Viterbi decoding over nodes and edges.

    Args:
        text: Input text.
        model: Model to use; when None, the default model is loaded via
            ``load_model()`` (which also supplies experiment_info).
        experiment_info: Experiment metadata; its "config" entry is read
            below, so it must be provided whenever ``model`` is passed in.
        silent: If True, suppress progress prints (errors still print).

    Returns:
        Tuple of (per-candidate result dicts, morphemes on the Viterbi path);
        ([], []) on any failure.
    """
    global _analyzer

    # Lazy default: load the bundled/experiment model on first use.
    if model is None:
        result = load_model()
        if result is None:
            return [], []
        model, experiment_info = result

    if not silent:
        print(f"Input text: {text}")

    # Create the MeCab analyzer once and reuse it across calls.
    if _analyzer is None:
        _analyzer = MeCabAnalyzer()

    # Step 1: candidate morphemes from MeCab, normalized and deduplicated.
    candidates = _analyzer.get_morpheme_candidates(text)
    candidates = normalize_mecab_candidates(candidates)
    candidates = dedup_morphemes(candidates)

    if not candidates:
        print("Error: Failed to obtain morpheme candidates")
        return [], []

    if not silent:
        print(f"#Candidates: {len(candidates)}")

    morphemes = candidates

    # Defensive check; downstream code indexes morphemes as a list.
    if not isinstance(morphemes, list):
        print(f"Warning: morphemes is not a list: {type(morphemes)}")
        morphemes = []

    # Attach lexical features via a throwaway DataModule (annotations_dir is
    # unused for feature computation, hence the "dummy" placeholder).
    dm_tmp = DataModule(annotations_dir="dummy", batch_size=1, num_workers=0, lexical_feature_dim=100000, silent=True)
    morphemes = dm_tmp.compute_lexical_features(morphemes, text)

    # Step 2: lattice edges between adjacent morpheme candidates.
    edges = build_adjacent_edges(morphemes)

    # Graph construction expects an "annotation" key on every node;
    # "?" marks unannotated (inference-time) morphemes.
    for morpheme in morphemes:
        if "annotation" not in morpheme:
            morpheme["annotation"] = "?"

    if not silent:
        print(f"Unified graph: {len(morphemes)} nodes, {len(edges)} edges")

    # Pull graph-building settings from the experiment's training config so
    # inference matches how the model was trained.
    features_config = experiment_info["config"].get("features", {})
    training_config = experiment_info["config"].get("training", {})
    edge_config = experiment_info["config"].get("edge_features", {})

    global _data_module_cache
    cache_key = str(training_config.get("annotations_dir", "annotations_kwdlc"))

    if cache_key not in _data_module_cache:
        _data_module_cache[cache_key] = DataModule(
            annotations_dir=training_config.get("annotations_dir", "annotations_kwdlc"),
            batch_size=1,
            num_workers=0,
            silent=silent,
            lexical_feature_dim=features_config.get("lexical_feature_dim", 100000),
            use_bidirectional_edges=edge_config.get("use_bidirectional_edges", True),
        )

    data_module = _data_module_cache[cache_key]

    graph = data_module.create_graph_from_morphemes_data(
        morphemes=morphemes,
        edges=edges,
        text=text,
        for_training=False,
    )

    if graph is None:
        print("Error: Failed to create PyTorch graph")
        return [], []

    # Inference runs on CPU unless the experiment info overrides it.
    device = torch.device("cpu")

    if experiment_info and "device" in experiment_info:
        device = experiment_info["device"]

    # Step 3: score every candidate morpheme with the model.
    with torch.no_grad():
        if not hasattr(graph, "lexical_indices") or graph.lexical_indices is None:
            print("Error: lexical_indices not found")
            return [], []

        # NOTE(review): positional argument order (indices, values,
        # edge_index, None, edge_attr) must match the model's forward
        # signature — confirm against MecariGATv2.forward.
        logits = model(
            graph.lexical_indices.to(device),
            graph.lexical_values.to(device),
            graph.edge_index.to(device),
            None,
            graph.edge_attr.to(device) if graph.edge_attr is not None else None,
        ).squeeze()

        # squeeze() collapses a single-node graph to a 0-d tensor; restore
        # the 1-d shape so indexing below works uniformly.
        if logits.dim() == 0:
            logits = logits.unsqueeze(0)
        probabilities = torch.sigmoid(logits)
        predictions = (probabilities >= 0.5).float()

        logits = logits.cpu()
        probabilities = probabilities.cpu()
        predictions = predictions.cpu()

    # Write per-node scores back onto the morpheme dicts; Viterbi reads the
    # "logit" key from these records.
    for i, morpheme in enumerate(morphemes):
        if i < len(predictions):
            morpheme["predicted_annotation"] = "+" if predictions[i] == 1 else "-"
            morpheme["logit"] = logits[i].item()
            morpheme["probability"] = probabilities[i].item()

    # Step 4: best path through the lattice.
    optimal_path = viterbi_decode_from_morphemes(logits, morphemes, edges, silent=silent)

    results = []
    for i, morpheme in enumerate(morphemes):
        is_in_optimal_path = optimal_path and i in optimal_path

        result = {
            "surface": morpheme["surface"],
            "pos": morpheme["pos"],
            "reading": morpheme["reading"],
            "predicted_annotation": morpheme.get("predicted_annotation", "?"),
            "logit": morpheme.get("logit", 0.0),
            "probability": morpheme.get("probability", 0.5),
            "in_optimal_path": is_in_optimal_path,
        }

        results.append(result)

    # Build the selected-path output, annotating each chosen morpheme with
    # how many candidates competed for its (start, end) span and its rank
    # (1-based insertion order) among them.
    optimal_morphemes = []
    if optimal_path:
        position_candidates = {}
        for i, m in enumerate(morphemes):
            pos_key = (m.get("start_pos", 0), m.get("end_pos", 0))
            if pos_key not in position_candidates:
                position_candidates[pos_key] = []
            position_candidates[pos_key].append(i)

        for idx in optimal_path:
            if idx < len(morphemes):
                # Copy so annotations below don't mutate the shared records.
                morph = morphemes[idx].copy()

                pos_key = (morph.get("start_pos", 0), morph.get("end_pos", 0))
                if pos_key in position_candidates:
                    candidates_at_pos = position_candidates[pos_key]
                    morph["num_candidates"] = len(candidates_at_pos)
                    morph["selected_rank"] = candidates_at_pos.index(idx) + 1 if idx in candidates_at_pos else 0
                optimal_morphemes.append(morph)

    return results, optimal_morphemes
|
|
|
|
|
|
|
|
def print_results(results, optimal_morphemes=None, verbose: bool = False):
    """Dump morphemes one per line in MeCab's "surface<TAB>feature-CSV" layout.

    Args:
        results: Per-candidate prediction dicts (fallback row source).
        optimal_morphemes: Morphemes on the Viterbi path; used when non-empty.
        verbose: Accepted for interface compatibility; unused here.
    """
    if not results:
        return

    def _feature_csv(entry):
        # Missing fields render as "*", mirroring MeCab's own output.
        base_form = entry.get("base_form", entry.get("lemma", "*")) or "*"
        reading = entry.get("reading", "*") or "*"
        columns = [
            entry.get("pos", "*"),
            entry.get("pos_detail1", "*"),
            entry.get("pos_detail2", "*"),
            entry.get("inflection_type", "*"),
            entry.get("inflection_form", "*"),
            base_form,
            reading,
        ]
        return ",".join(map(str, columns))

    if optimal_morphemes:
        rows = optimal_morphemes
    else:
        # No decoded path: synthesize minimal rows from the raw results.
        rows = []
        for r in results:
            rows.append(
                {
                    "surface": r.get("surface", ""),
                    "pos": r.get("pos", "*"),
                    "pos_detail1": "*",
                    "pos_detail2": "*",
                    "inflection_type": "*",
                    "inflection_form": "*",
                    "base_form": r.get("surface", ""),
                    "reading": r.get("reading", "*"),
                }
            )

    for row in rows:
        print(f"{row.get('surface', '')}\t{_feature_csv(row)}")
    print("EOS")
|
|
|
|
|
|
|
|
def main():
    """Main inference entrypoint.

    With --text, runs a single prediction and prints the result; otherwise
    drops into an interactive read-predict-print loop until the user quits
    (via "quit"/"exit"/"q", EOF, or Ctrl-C).
    """
    import argparse

    parser = argparse.ArgumentParser(description="Mecari morphological analysis inference")
    parser.add_argument("--text", "-t", help="Input text directly")
    parser.add_argument("--experiment", "-e", help="Experiment name to load (e.g., gat_20250730_145624)")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output (include UD POS)")
    args = parser.parse_args()

    # Load the requested experiment, or the default sample model.
    if args.experiment:
        result = load_model(experiment_name=args.experiment)
    else:
        result = load_model()

    if result is None:
        # load_model already printed the reason for failure.
        return

    model, experiment_info = result

    if args.text:
        # One-shot mode: predict for the given text and exit.
        result = predict_morphemes_from_text(args.text, model, experiment_info, silent=not args.verbose)
        if result:
            results, optimal_morphemes = result
            print_results(results, optimal_morphemes, verbose=args.verbose)
        else:
            print("Inference failed.")

    else:
        # Interactive mode: REPL-style loop over user-entered sentences.
        print("\nMecari morphological inference")
        print("Enter text (e.g., Tokyo is nice)")
        print("Type 'quit' or 'exit' to finish.\n")

        while True:
            try:
                user_input = input("Input: ").strip()

                if user_input.lower() in ["quit", "exit", "q"]:
                    print("Exiting.")
                    break

                # Ignore empty lines.
                if not user_input:
                    continue

                print(f"Text: {user_input}")

                result = predict_morphemes_from_text(user_input, model, experiment_info, silent=not args.verbose)
                if result:
                    results, optimal_morphemes = result
                    print_results(results, optimal_morphemes, verbose=args.verbose)
                else:
                    print("Inference failed.")

                print()

            except EOFError:
                print("\nExiting.")
                break
            except KeyboardInterrupt:
                print("\nExiting.")
                break
            except Exception as e:
                # Report unexpected errors but keep the session alive.
                import traceback

                print(f"\nAn error occurred: {e}")
                traceback.print_exc()
                continue
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|