| | |
| | """ |
| | Aphasia classification inference (cleaned). |
| | - Respects model_dir argument |
| | - Correctly parses durations like ["word", 300] and [start, end] |
| | - Removes duplicate load_state_dict |
| | - Adds predict_from_chajson(json_path, ...) helper |
| | """ |
| |
|
| | import json as json |
| | import os |
| | import math |
| | from dataclasses import dataclass |
| | from typing import Dict, List, Optional, Tuple |
| | from collections import defaultdict |
| |
|
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import pandas as pd |
| | from transformers import AutoTokenizer, AutoModel |
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class ModelConfig:
    """Hyper-parameters for the aphasia classifier and its feature branches.

    Attribute names and defaults must stay stable: the inference system
    constructs `ModelConfig()` and only overrides `model_name`.
    """
    model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    max_length: int = 512
    hidden_size: int = 768
    pos_vocab_size: int = 150
    pos_emb_dim: int = 64
    grammar_dim: int = 3
    grammar_hidden_dim: int = 64
    duration_hidden_dim: int = 128
    prosody_dim: int = 32
    num_attention_heads: int = 8
    attention_dropout: float = 0.3
    # Optional + default_factory: the old annotation (`List[int] = None`) was
    # type-incorrect; a factory gives each instance its own fresh list.
    classifier_hidden_dims: Optional[List[int]] = field(default_factory=lambda: [512, 256])
    dropout_rate: float = 0.3

    def __post_init__(self):
        # Backward compatible: callers may still pass None explicitly to mean
        # "use the default classifier stack".
        if self.classifier_hidden_dims is None:
            self.classifier_hidden_dims = [512, 256]
| |
|
| | class StablePositionalEncoding(nn.Module): |
| | def __init__(self, d_model: int, max_len: int = 5000): |
| | super().__init__() |
| | self.d_model = d_model |
| | pe = torch.zeros(max_len, d_model) |
| | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) |
| | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) |
| | pe[:, 0::2] = torch.sin(position * div_term) |
| | pe[:, 1::2] = torch.cos(position * div_term) |
| | self.register_buffer('pe', pe.unsqueeze(0)) |
| | self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01) |
| | def forward(self, x): |
| | seq_len = x.size(1) |
| | sinusoidal = self.pe[:, :seq_len, :].to(x.device) |
| | learnable = self.learnable_pe[:seq_len, :].unsqueeze(0).expand(x.size(0), -1, -1) |
| | return x + 0.1 * (sinusoidal + learnable) |
| |
|
class StableMultiHeadAttention(nn.Module):
    """Standard multi-head self-attention with a residual + LayerNorm output."""

    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        assert feature_dim % num_heads == 0
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads
        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def _split_heads(self, proj, batch, steps):
        # (B, T, D) -> (B, H, T, D/H)
        return proj.view(batch, steps, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        batch, steps, _ = x.size()
        q = self._split_heads(self.query(x), batch, steps)
        k = self._split_heads(self.key(x), batch, steps)
        v = self._split_heads(self.value(x), batch, steps)
        scores = q.matmul(k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # Broadcast a (B, T) padding mask over heads and query positions.
            broadcast = mask.unsqueeze(1).unsqueeze(1) if mask.dim() == 2 else mask
            scores.masked_fill_(broadcast == 0, -1e9)
        weights = self.dropout(F.softmax(scores, dim=-1))
        context = weights.matmul(v).transpose(1, 2).contiguous().view(batch, steps, self.feature_dim)
        # Residual connection followed by LayerNorm.
        return self.layer_norm(self.output_proj(context) + x)
| |
|
class StableLinguisticFeatureExtractor(nn.Module):
    """Fuses POS, grammar, duration and prosody word features into one pooled vector.

    Submodule attribute names are part of the checkpoint state_dict keys and
    must not change.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.grammar_hidden_dim),
            nn.Dropout(config.dropout_rate * 0.3),
        )
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.duration_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.duration_hidden_dim),
        )
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.prosody_dim, config.prosody_dim),
            nn.ReLU(),
            nn.LayerNorm(config.prosody_dim),
        )
        total_feature_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                             config.duration_hidden_dim + config.prosody_dim)
        # Halves the concatenated feature width before it reaches the classifier.
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.dropout_rate),
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        # Clamp guards against out-of-vocabulary POS ids.
        safe_pos = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
        streams = [
            self.pos_attention(self.pos_embedding(safe_pos), attention_mask),
            self.grammar_projection(grammar_ids.float()),
            self.duration_projection(durations.unsqueeze(-1).float()),
            self.prosody_projection(prosody_features.float()),
        ]
        fused = self.feature_fusion(torch.cat(streams, dim=-1))
        # Masked mean-pool over the sequence dimension.
        weights = attention_mask.unsqueeze(-1).float()
        return (fused * weights).sum(dim=1) / weights.sum(dim=1)
| |
|
class StableAphasiaClassifier(nn.Module):
    """BERT encoder plus a linguistic-feature branch feeding multi-task heads.

    Outputs per-class aphasia logits together with auxiliary severity
    (4-way softmax) and fluency (sigmoid scalar) predictions. Submodule
    attribute names fix the checkpoint state_dict keys — do not rename.
    """
    def __init__(self, config: ModelConfig, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        # Downloads/loads the pretrained encoder named in the config.
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.bert_config = self.bert.config
        self.positional_encoder = StablePositionalEncoding(d_model=self.bert_config.hidden_size,
                                                           max_len=config.max_length)
        self.linguistic_extractor = StableLinguisticFeatureExtractor(config)
        bert_dim = self.bert_config.hidden_size
        # Must match the output width of StableLinguisticFeatureExtractor
        # (its feature_fusion halves the concatenated feature dim).
        lingu_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                     config.duration_hidden_dim + config.prosody_dim) // 2
        self.feature_fusion = nn.Sequential(
            nn.Linear(bert_dim + lingu_dim, bert_dim),
            nn.LayerNorm(bert_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate)
        )
        self.classifier = self._build_classifier(bert_dim, num_labels)
        self.severity_head = nn.Sequential(nn.Linear(bert_dim, 4), nn.Softmax(dim=-1))
        self.fluency_head = nn.Sequential(nn.Linear(bert_dim, 1), nn.Sigmoid())
    def _build_classifier(self, input_dim: int, num_labels: int):
        """Stack Linear/LayerNorm/Tanh/Dropout blocks per config, then a logits layer."""
        layers, cur = [], input_dim
        for h in self.config.classifier_hidden_dims:
            layers += [nn.Linear(cur, h), nn.LayerNorm(h), nn.Tanh(), nn.Dropout(self.config.dropout_rate)]
            cur = h
        layers.append(nn.Linear(cur, num_labels))
        return nn.Sequential(*layers)
    def _attention_pooling(self, seq_out, attn_mask):
        """Pool the sequence with softmax weights derived from per-token feature sums."""
        attn_w = torch.softmax(torch.sum(seq_out, dim=-1, keepdim=True), dim=1)
        # Zero out padding positions, then renormalize so weights sum to 1.
        attn_w = attn_w * attn_mask.unsqueeze(-1).float()
        attn_w = attn_w / (torch.sum(attn_w, dim=1, keepdim=True) + 1e-9)
        return torch.sum(seq_out * attn_w, dim=1)
    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        """Return dict with 'logits', 'severity_pred', 'fluency_pred', 'loss'.

        'loss' is always None here — this path is inference-only; `labels`
        is accepted but unused (presumably for training-API compatibility).
        """
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        seq_out = bert_out.last_hidden_state
        pos_enh = self.positional_encoder(seq_out)
        pooled = self._attention_pooling(pos_enh, attention_mask)
        if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
            # Prosody is optional; substitute zeros when the caller omits it.
            if prosody_features is None:
                b, t = input_ids.size()
                prosody_features = torch.zeros(b, t, self.config.prosody_dim, device=input_ids.device)
            ling = self.linguistic_extractor(word_pos_ids, word_grammar_ids, word_durations,
                                            prosody_features, attention_mask)
        else:
            # No word-level features supplied: use a zero linguistic vector.
            ling = torch.zeros(input_ids.size(0),
                               (self.config.pos_emb_dim + self.config.grammar_hidden_dim +
                                self.config.duration_hidden_dim + self.config.prosody_dim) // 2,
                               device=input_ids.device)
        fused = self.feature_fusion(torch.cat([pooled, ling], dim=1))
        logits = self.classifier(fused)
        severity_pred = self.severity_head(fused)
        fluency_pred = self.fluency_head(fused)
        return {"logits": logits, "severity_pred": severity_pred, "fluency_pred": fluency_pred, "loss": None}
| |
|
| |
|
| | |
| | |
| | |
| |
|
class AphasiaInferenceSystem:
    """Aphasia classification inference system.

    Loads a trained StableAphasiaClassifier (weights, tokenizer, label
    mapping) from `model_dir` and exposes single-sentence and batch
    prediction plus CSV/JSON reporting helpers.
    """

    def __init__(self, model_dir: str):
        # model_dir is expected to hold config.json, tokenizer files and
        # pytorch_model.bin (see load_configuration / load_model).
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Human-readable descriptions attached to each predicted class in the output.
        self.aphasia_descriptions = {
            "BROCA": {"name": "Broca's Aphasia (Non-fluent)", "description":
                "Characterized by limited speech output, difficulty with grammar and sentence formation, but relatively preserved comprehension. Speech is typically effortful and halting.",
                "features": ["Non-fluent speech", "Preserved comprehension", "Grammar difficulties", "Word-finding problems"]},
            "TRANSMOTOR": {"name": "Trans-cortical Motor Aphasia", "description":
                "Similar to Broca's aphasia but with preserved repetition abilities. Speech is non-fluent with good comprehension.",
                "features": ["Non-fluent speech", "Good repetition", "Preserved comprehension", "Grammar difficulties"]},
            "NOTAPHASICBYWAB": {"name": "Not Aphasic by WAB", "description":
                "Individuals who do not meet the criteria for aphasia according to the Western Aphasia Battery assessment.",
                "features": ["Normal language function", "No significant language impairment", "Good comprehension", "Fluent speech"]},
            "CONDUCTION": {"name": "Conduction Aphasia", "description":
                "Characterized by fluent speech with good comprehension but severely impaired repetition. Often involves phonemic paraphasias.",
                "features": ["Fluent speech", "Good comprehension", "Poor repetition", "Phonemic errors"]},
            "WERNICKE": {"name": "Wernicke's Aphasia (Fluent)", "description":
                "Fluent but often meaningless speech with poor comprehension. Speech may contain neologisms and jargon.",
                "features": ["Fluent speech", "Poor comprehension", "Jargon speech", "Neologisms"]},
            "ANOMIC": {"name": "Anomic Aphasia", "description":
                "Primarily characterized by word-finding difficulties with otherwise relatively preserved language abilities.",
                "features": ["Word-finding difficulties", "Good comprehension", "Fluent speech", "Circumlocution"]},
            "GLOBAL": {"name": "Global Aphasia", "description":
                "Severe impairment in all language modalities - comprehension, production, repetition, and naming.",
                "features": ["Severe comprehension deficit", "Non-fluent speech", "Poor repetition", "Severe naming difficulties"]},
            "ISOLATION": {"name": "Isolation Syndrome", "description":
                "Rare condition with preserved repetition but severely impaired comprehension and spontaneous speech.",
                "features": ["Good repetition", "Poor comprehension", "Limited spontaneous speech", "Echolalia"]},
            "TRANSSENSORY": {"name": "Trans-cortical Sensory Aphasia", "description":
                "Fluent speech with good repetition but impaired comprehension, similar to Wernicke's but with preserved repetition.",
                "features": ["Fluent speech", "Good repetition", "Poor comprehension", "Semantic errors"]}
        }

        self.load_configuration()
        self.load_model()
        print(f"推理系統初始化完成,使用設備: {self.device}")

    def load_configuration(self):
        """Load label mapping, label count and model name from config.json.

        Falls back to the default 9-way mapping and PubMedBERT when the file
        is absent or lacks a key.
        """
        cfg_path = os.path.join(self.model_dir, "config.json")
        if os.path.exists(cfg_path):
            with open(cfg_path, "r", encoding="utf-8") as f:
                cfg = json.load(f)
            self.aphasia_types_mapping = cfg.get("aphasia_types_mapping", {
                "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
                "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
                "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
            })
            self.num_labels = cfg.get("num_labels", 9)
            self.model_name = cfg.get("model_name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
        else:
            # No config.json: use the built-in defaults.
            self.aphasia_types_mapping = {
                "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
                "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
                "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
            }
            self.num_labels = 9
            self.model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
        # Inverse mapping used to turn argmax ids back into class names.
        self.id_to_aphasia_type = {v: k for k, v in self.aphasia_types_mapping.items()}

    def load_model(self):
        """Build tokenizer + model, load weights from pytorch_model.bin, set eval mode.

        Raises:
            FileNotFoundError: if pytorch_model.bin is missing in model_dir.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, use_fast=True)

        # Guarantee a pad token exists (some checkpoints ship without one).
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            elif self.tokenizer.unk_token is not None:
                self.tokenizer.pad_token = self.tokenizer.unk_token
            else:
                self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        # Re-apply any extra tokens that were added during training.
        add_path = os.path.join(self.model_dir, "added_tokens.json")
        if os.path.exists(add_path):
            with open(add_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            tokens = list(data.keys()) if isinstance(data, dict) else data
            if tokens:
                self.tokenizer.add_tokens(tokens)

        self.config = ModelConfig()
        self.config.model_name = self.model_name

        self.model = StableAphasiaClassifier(self.config, self.num_labels)
        # Embedding matrix must match the (possibly extended) tokenizer vocab
        # before the checkpoint weights can be loaded.
        self.model.bert.resize_token_embeddings(len(self.tokenizer))

        model_path = os.path.join(self.model_dir, "pytorch_model.bin")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型權重文件不存在: {model_path}")
        state = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state)

        self.model.to(self.device)
        self.model.eval()

    def _dur_to_float(self, d) -> float:
        """Robustly parse duration from various shapes:
        - number
        - ["word", ms]
        - [start, end]
        - {"dur": ms} (future-proof)

        Returns 0.0 for anything unrecognized.
        """
        if isinstance(d, (int, float)):
            return float(d)
        if isinstance(d, list):
            if len(d) == 2:
                a, b = d[0], d[1]
                # ["word", ms] -> the numeric element is the duration.
                if isinstance(a, str) and isinstance(b, (int, float)):
                    return float(b)
                # [start, end] -> duration is the difference.
                if isinstance(a, (int, float)) and isinstance(b, (int, float)):
                    return float(b) - float(a)
        if isinstance(d, dict):
            for k in ("dur", "duration", "ms"):
                if k in d and isinstance(d[k], (int, float)):
                    return float(d[k])
        return 0.0

    def _extract_prosodic_features(self, durations, tokens):
        """Summarize word durations into a fixed-width prosody vector.

        Only four statistics are computed (mean, std, median, count of
        long-duration outliers); the rest of the vector is zero-padded to
        `config.prosody_dim`. `tokens` is currently unused — kept for
        signature stability.
        """
        vals = []
        for d in durations:
            vals.append(self._dur_to_float(d))
        vals = [v for v in vals if v > 0]
        if not vals:
            return [0.0] * self.config.prosody_dim
        features = [
            float(np.mean(vals)),
            float(np.std(vals)),
            float(np.median(vals)),
            # Count of words lasting >1.5x the mean (hesitation proxy).
            float(len([v for v in vals if v > (np.mean(vals) * 1.5)])),
        ]
        while len(features) < self.config.prosody_dim:
            features.append(0.0)
        return features[:self.config.prosody_dim]

    def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
        """Expand word-level features to subtoken positions of the encoded input.

        Output lists have exactly `config.max_length` entries: a leading zero
        slot (for [CLS]), max_length-2 body slots, and a trailing zero slot
        (for [SEP]). `encoded` is currently unused — the mapping is rebuilt
        by re-tokenizing each word.
        """
        # Map each subtoken index back to the index of its source word.
        subtoken_to_token = []
        for idx, tok in enumerate(tokens):
            subtoks = self.tokenizer.tokenize(tok)
            # max(1, ...) keeps words that tokenize to nothing aligned.
            subtoken_to_token.extend([idx] * max(1, len(subtoks)))

        # Slot 0 corresponds to the [CLS] position.
        aligned_pos = [0]
        aligned_grammar = [[0, 0, 0]]
        aligned_durations = [0.0]

        max_body = self.config.max_length - 2
        for st_idx in range(max_body):
            if st_idx < len(subtoken_to_token):
                orig = subtoken_to_token[st_idx]
                aligned_pos.append(pos_ids[orig] if orig < len(pos_ids) else 0)
                aligned_grammar.append(grammar_ids[orig] if orig < len(grammar_ids) else [0, 0, 0])
                aligned_durations.append(self._dur_to_float(durations[orig]) if orig < len(durations) else 0.0)
            else:
                # Padding region beyond the real subtokens.
                aligned_pos.append(0)
                aligned_grammar.append([0, 0, 0])
                aligned_durations.append(0.0)

        # Final slot corresponds to the [SEP] position.
        aligned_pos.append(0)
        aligned_grammar.append([0, 0, 0])
        aligned_durations.append(0.0)
        return aligned_pos, aligned_grammar, aligned_durations

    def preprocess_sentence(self, sentence_data: dict) -> Optional[dict]:
        """Flatten a sentence's PAR utterances into model-ready tensors.

        Returns None when no participant tokens are present. Dialogues after
        the first are separated by a literal "[DIALOGUE]" marker token.
        """
        all_tokens, all_pos, all_grammar, all_durations = [], [], [], []
        for d_idx, dialogue in enumerate(sentence_data.get("dialogues", [])):
            if d_idx > 0:
                all_tokens.append("[DIALOGUE]")
                all_pos.append(0)
                all_grammar.append([0, 0, 0])
                all_durations.append(0.0)
            # Only participant (PAR) turns are used; INV turns are ignored.
            for par in dialogue.get("PAR", []):
                if "tokens" in par and par["tokens"]:
                    toks = par["tokens"]
                    pos_ids = par.get("word_pos_ids", [0] * len(toks))
                    gra_ids = par.get("word_grammar_ids", [[0, 0, 0]] * len(toks))
                    durs = par.get("word_durations", [0.0] * len(toks))
                    all_tokens.extend(toks)
                    all_pos.extend(pos_ids)
                    all_grammar.extend(gra_ids)
                    all_durations.extend(durs)
        if not all_tokens:
            return None

        text = " ".join(all_tokens)
        enc = self.tokenizer(text, max_length=self.config.max_length, padding="max_length",
                             truncation=True, return_tensors="pt")
        aligned_pos, aligned_gra, aligned_dur = self._align_features(
            all_tokens, all_pos, all_grammar, all_durations, enc
        )
        prosody = self._extract_prosodic_features(all_durations, all_tokens)
        # The same sentence-level prosody vector is repeated at every position.
        prosody_tensor = torch.tensor(prosody).unsqueeze(0).repeat(self.config.max_length, 1)

        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
            "word_grammar_ids": torch.tensor(aligned_gra, dtype=torch.long),
            "word_durations": torch.tensor(aligned_dur, dtype=torch.float),
            "prosody_features": prosody_tensor.float(),
            "sentence_id": sentence_data.get("sentence_id", "unknown"),
            "original_tokens": all_tokens,
            "text": text
        }

    def predict_single(self, sentence_data: dict) -> dict:
        """Classify one sentence dict; returns a rich result or an error dict."""
        proc = self.preprocess_sentence(sentence_data)
        if proc is None:
            return {"error": "無法處理輸入數據", "sentence_id": sentence_data.get("sentence_id", "unknown")}
        # Add the batch dimension and move everything to the model device.
        inp = {
            "input_ids": proc["input_ids"].unsqueeze(0).to(self.device),
            "attention_mask": proc["attention_mask"].unsqueeze(0).to(self.device),
            "word_pos_ids": proc["word_pos_ids"].unsqueeze(0).to(self.device),
            "word_grammar_ids": proc["word_grammar_ids"].unsqueeze(0).to(self.device),
            "word_durations": proc["word_durations"].unsqueeze(0).to(self.device),
            "prosody_features": proc["prosody_features"].unsqueeze(0).to(self.device),
        }
        with torch.no_grad():
            out = self.model(**inp)
            logits = out["logits"]
            probs = F.softmax(logits, dim=1).cpu().numpy()[0]
            pred_id = int(np.argmax(probs))
            sev = out["severity_pred"].cpu().numpy()[0]
            flu = float(out["fluency_pred"].cpu().numpy()[0][0])

        pred_type = self.id_to_aphasia_type[pred_id]
        conf = float(probs[pred_id])

        # Per-class probability table, sorted most-likely first.
        dist = {}
        for a_type, t_id in self.aphasia_types_mapping.items():
            dist[a_type] = {"probability": float(probs[t_id]), "percentage": f"{probs[t_id]*100:.2f}%"}

        sorted_dist = dict(sorted(dist.items(), key=lambda x: x[1]["probability"], reverse=True))
        return {
            "sentence_id": proc["sentence_id"],
            "input_text": proc["text"],
            "original_tokens": proc["original_tokens"],
            "prediction": {
                "predicted_class": pred_type,
                "confidence": conf,
                "confidence_percentage": f"{conf*100:.2f}%"
            },
            "class_description": self.aphasia_descriptions.get(pred_type, {
                "name": pred_type, "description": "Description not available", "features": []
            }),
            "probability_distribution": sorted_dist,
            "additional_predictions": {
                "severity_distribution": {
                    "level_0": float(sev[0]), "level_1": float(sev[1]),
                    "level_2": float(sev[2]), "level_3": float(sev[3])
                },
                "predicted_severity_level": int(np.argmax(sev)),
                "fluency_score": flu,
                "fluency_rating": "High" if flu > 0.7 else ("Medium" if flu > 0.4 else "Low"),
            }
        }

    def predict_batch(self, input_file: str, output_file: Optional[str] = None) -> Dict:
        """Run predict_single over every sentence in a JSON file.

        The file must contain a top-level "sentences" list. Optionally writes
        the combined result (summary + per-sentence predictions) to
        `output_file`.
        """
        with open(input_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        sentences = data.get("sentences", [])
        results = []
        print(f"開始處理 {len(sentences)} 個句子...")
        for i, s in enumerate(sentences):
            print(f"處理第 {i+1}/{len(sentences)} 個句子...")
            results.append(self.predict_single(s))
        summary = self._generate_summary(results)
        final = {"summary": summary, "total_sentences": len(results), "predictions": results}
        if output_file:
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(final, f, ensure_ascii=False, indent=2)
            print(f"結果已保存到: {output_file}")
        return final

    def _generate_summary(self, results: List[dict]) -> dict:
        """Aggregate class/severity distributions and confidence statistics.

        Error results are skipped from all counters. NOTE(review): the
        percentages divide by len(results) including errored entries —
        presumably intentional ("percent of all inputs"); confirm.
        """
        if not results:
            return {}
        class_counts = defaultdict(int)
        confs, flus = [], []
        sev_counts = defaultdict(int)
        for r in results:
            if "error" in r:
                continue
            c = r["prediction"]["predicted_class"]
            class_counts[c] += 1
            confs.append(r["prediction"]["confidence"])
            flus.append(r["additional_predictions"]["fluency_score"])
            sev_counts[r["additional_predictions"]["predicted_severity_level"]] += 1
        avg_conf = float(np.mean(confs)) if confs else 0.0
        avg_flu = float(np.mean(flus)) if flus else 0.0
        return {
            "classification_distribution": dict(class_counts),
            "classification_percentages": {k: f"{v/len(results)*100:.1f}%" for k, v in class_counts.items()},
            "average_confidence": f"{avg_conf:.3f}",
            "average_fluency_score": f"{avg_flu:.3f}",
            "severity_distribution": dict(sev_counts),
            "confidence_statistics": {} if not confs else {
                "mean": f"{np.mean(confs):.3f}",
                "std": f"{np.std(confs):.3f}",
                "min": f"{np.min(confs):.3f}",
                "max": f"{np.max(confs):.3f}",
            },
            "most_common_prediction": max(class_counts.items(), key=lambda x: x[1])[0] if class_counts else "None",
        }

    def generate_detailed_report(self, results: List[dict], output_dir: str = "./inference_results"):
        """Write detailed_predictions.csv and summary_statistics.json to output_dir.

        Returns the per-sentence DataFrame, or None when every result errored.
        """
        os.makedirs(output_dir, exist_ok=True)
        rows = []
        for r in results:
            if "error" in r:
                continue
            row = {
                "sentence_id": r["sentence_id"],
                "predicted_class": r["prediction"]["predicted_class"],
                "confidence": r["prediction"]["confidence"],
                "class_name": r["class_description"]["name"],
                "severity_level": r["additional_predictions"]["predicted_severity_level"],
                "fluency_score": r["additional_predictions"]["fluency_score"],
                "fluency_rating": r["additional_predictions"]["fluency_rating"],
                "input_text": r["input_text"],
            }
            # One prob_<CLASS> column per aphasia type.
            for a_type, info in r["probability_distribution"].items():
                row[f"prob_{a_type}"] = info["probability"]
            rows.append(row)
        if not rows:
            return None
        df = pd.DataFrame(rows)
        df.to_csv(os.path.join(output_dir, "detailed_predictions.csv"), index=False, encoding="utf-8")
        summary_stats = {
            "total_predictions": int(len(rows)),
            "class_distribution": df["predicted_class"].value_counts().to_dict(),
            "average_confidence": float(df["confidence"].mean()),
            "confidence_std": float(df["confidence"].std()),
            "average_fluency": float(df["fluency_score"].mean()),
            "fluency_std": float(df["fluency_score"].std()),
            "severity_distribution": df["severity_level"].value_counts().to_dict(),
        }
        with open(os.path.join(output_dir, "summary_statistics.json"), "w", encoding="utf-8") as f:
            json.dump(summary_stats, f, ensure_ascii=False, indent=2)
        print(f"詳細報告已生成並保存到: {output_dir}")
        return df
| |
|
| |
|
| | |
| | |
| | |
| |
|
def predict_from_chajson(model_dir: str, chajson_path: str, output_file: Optional[str] = None) -> Dict:
    """Run inference directly on the JSON produced by cha_json.py.

    - If the JSON already contains 'sentences', runs per-sentence as usual.
    - If it only contains 'text_all', wraps it as a single pseudo-sentence
      (POS / grammar / duration features are zeroed since they are unknown).

    Args:
        model_dir: directory holding the trained model artifacts.
        chajson_path: path to the cha_json.py output JSON.
        output_file: optional path to save the combined prediction JSON.

    Returns:
        The dict produced by AphasiaInferenceSystem.predict_batch.
    """
    import tempfile  # local import: only needed for the synthetic fallback

    with open(chajson_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    inf = AphasiaInferenceSystem(model_dir)

    if data.get("sentences"):
        return inf.predict_batch(chajson_path, output_file=output_file)

    # Fallback: build one pseudo-sentence from the raw transcript text.
    tokens = data.get("text_all", "").split()
    fake = {
        "sentences": [{
            "sentence_id": "S1",
            "dialogues": [{
                "INV": [],
                "PAR": [{
                    "tokens": tokens,
                    "word_pos_ids": [0] * len(tokens),
                    "word_grammar_ids": [[0, 0, 0]] * len(tokens),
                    "word_durations": [0.0] * len(tokens),
                }],
            }],
        }]
    }
    # Use a real temp file instead of writing next to the input (which may be
    # in a read-only location), and guarantee cleanup via try/finally.
    fd, tmp_path = tempfile.mkstemp(suffix="._synthetic.json")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(fake, f, ensure_ascii=False, indent=2)
        return inf.predict_batch(tmp_path, output_file=output_file)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
| |
|
def format_result(pred: dict, style: str = "json") -> str:
    """Back-compat formatter. 'pred' is the dict returned by predict_*."""
    if style == "json":
        return json.dumps(pred, ensure_ascii=False, indent=2)
    # Plain-text summary is only available for batch-style results.
    if isinstance(pred, dict) and "summary" in pred:
        summary = pred["summary"]
        fields = (
            ("Total sentences", pred.get("total_sentences", 0)),
            ("Avg confidence", summary.get("average_confidence", "N/A")),
            ("Avg fluency", summary.get("average_fluency_score", "N/A")),
            ("Most common", summary.get("most_common_prediction", "N/A")),
        )
        return "\n".join(f"{label}: {value}" for label, value in fields)
    return str(pred)
| |
|
| |
|
| | |
| |
|
def main():
    """CLI entry point: parse arguments, run batch inference, print a summary."""
    import argparse
    parser = argparse.ArgumentParser(description="失語症分類推理系統")
    parser.add_argument("--model_dir", type=str, required=False, default="./adaptive_aphasia_model",
                        help="訓練好的模型目錄路徑")
    parser.add_argument("--input_file", type=str, required=True,
                        help="輸入JSON文件(cha_json 的輸出)")
    parser.add_argument("--output_file", type=str, default="./aphasia_predictions.json",
                        help="輸出JSON文件路徑")
    parser.add_argument("--report_dir", type=str, default="./inference_results",
                        help="詳細報告輸出目錄")
    parser.add_argument("--generate_report", action="store_true",
                        help="是否生成詳細的CSV報告")
    args = parser.parse_args()

    try:
        print("正在初始化推理系統...")
        # 'system' avoids shadowing the stdlib module name 'sys'.
        system = AphasiaInferenceSystem(args.model_dir)

        print("開始執行批次預測...")
        results = system.predict_batch(args.input_file, args.output_file)

        if args.generate_report:
            print("生成詳細報告...")
            system.generate_detailed_report(results["predictions"], args.report_dir)

        print("\n=== 預測摘要 ===")
        summary = results["summary"]
        print(f"總句子數: {results['total_sentences']}")
        print(f"平均信心度: {summary.get('average_confidence', 'N/A')}")
        print(f"平均流利度: {summary.get('average_fluency_score', 'N/A')}")
        print(f"最常見預測: {summary.get('most_common_prediction', 'N/A')}")
        print("\n類別分佈:")
        for name, count in summary.get("classification_distribution", {}).items():
            pct = summary.get("classification_percentages", {}).get(name, "0%")
            print(f" {name}: {count} ({pct})")
        print(f"\n結果已保存到: {args.output_file}")
    except Exception as e:
        print(f"錯誤: {str(e)}")
        import traceback; traceback.print_exc()
| |
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
| |
|