Rename Json__Output.py to output.py
Browse files- Json__Output.py +0 -896
- output.py +642 -0
Json__Output.py
DELETED
|
@@ -1,896 +0,0 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
"""
|
| 3 |
-
失語症分類推理系統
|
| 4 |
-
用於載入訓練好的模型並對新的語音數據進行分類預測
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
import json
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn as nn
|
| 10 |
-
import torch.nn.functional as F
|
| 11 |
-
import numpy as np
|
| 12 |
-
import os
|
| 13 |
-
import math
|
| 14 |
-
from typing import Dict, List, Optional, Tuple
|
| 15 |
-
from dataclasses import dataclass
|
| 16 |
-
import pandas as pd
|
| 17 |
-
from transformers import AutoTokenizer, AutoModel
|
| 18 |
-
from collections import defaultdict
|
| 19 |
-
|
| 20 |
-
# Model architecture hyper-parameters (must match the training code exactly).
@dataclass
class ModelConfig:
    """Hyper-parameters for the aphasia classifier architecture.

    Instances are plain value objects; ``classifier_hidden_dims`` defaults to
    ``[512, 256]`` via ``__post_init__`` so no mutable default is shared.
    """
    model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    max_length: int = 512          # BERT sequence length (tokens)
    hidden_size: int = 768
    pos_vocab_size: int = 150      # size of the POS-tag id space
    pos_emb_dim: int = 64
    grammar_dim: int = 3           # per-token grammar indicator vector length
    grammar_hidden_dim: int = 64
    duration_hidden_dim: int = 128
    prosody_dim: int = 32          # utterance-level prosodic feature length
    num_attention_heads: int = 8
    attention_dropout: float = 0.3
    # FIX: was annotated ``List[int] = None`` — the value is legitimately None
    # until __post_init__ fills it in, so the annotation must be Optional.
    classifier_hidden_dims: Optional[List[int]] = None
    dropout_rate: float = 0.3

    def __post_init__(self):
        # Supply the default hidden-layer sizes without a mutable default arg.
        if self.classifier_hidden_dims is None:
            self.classifier_hidden_dims = [512, 256]
|
| 40 |
-
|
| 41 |
-
class StablePositionalEncoding(nn.Module):
    """Adds a damped blend of sinusoidal and learnable positional encodings.

    The combined encoding is scaled by 0.1 before being added to the input,
    keeping the perturbation of the incoming hidden states small.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        # Classic transformer sinusoidal table, stored as (1, max_len, d_model).
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freq_scale = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freq_scale)
        table[:, 1::2] = torch.cos(positions * freq_scale)
        self.register_buffer('pe', table.unsqueeze(0))

        # Small learnable component, initialised near zero.
        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        """Return ``x + 0.1 * (sinusoidal + learnable)`` for the first seq_len positions."""
        length = x.size(1)
        fixed = self.pe[:, :length, :].to(x.device)
        learned = self.learnable_pe[:length, :].unsqueeze(0).expand(x.size(0), -1, -1)
        return x + 0.1 * (fixed + learned)
|
| 62 |
-
|
| 63 |
-
class StableMultiHeadAttention(nn.Module):
    """Multi-head self-attention followed by a residual add and LayerNorm."""

    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads

        # The feature dimension must split evenly across the heads.
        assert feature_dim % num_heads == 0

        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, x, mask=None):
        """Attend over ``x`` (batch, seq, feature_dim); positions where ``mask`` is 0 are ignored."""
        n_batch, n_seq = x.size(0), x.size(1)

        def split_heads(proj):
            # (batch, seq, dim) -> (batch, heads, seq, head_dim)
            return proj(x).view(n_batch, n_seq, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query)
        k = split_heads(self.key)
        v = split_heads(self.value)

        # Scaled dot-product scores: (batch, heads, seq, seq).
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            if mask.dim() == 2:
                # Broadcast a (batch, seq) padding mask across heads and queries.
                mask = mask.unsqueeze(1).unsqueeze(1)
            attn_scores.masked_fill_(mask == 0, -1e9)

        weights = self.dropout(F.softmax(attn_scores, dim=-1))

        merged = torch.matmul(weights, v).transpose(1, 2).contiguous().view(
            n_batch, n_seq, self.feature_dim
        )
        return self.layer_norm(self.output_proj(merged) + x)
|
| 101 |
-
|
| 102 |
-
class StableLinguisticFeatureExtractor(nn.Module):
    """Fuses POS, grammar, word-duration and prosody inputs into one pooled vector.

    Each modality is projected separately, concatenated along the feature axis,
    fused, and mean-pooled over the positions allowed by ``attention_mask``.
    The output has shape ``(batch, total_feature_dim // 2)``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # POS tags: embedding followed by self-attention over the tag sequence.
        self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)

        # Per-token grammar indicator vector (grammar_dim values per token).
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.grammar_hidden_dim),
            nn.Dropout(config.dropout_rate * 0.3)
        )

        # Scalar word duration per token.
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.duration_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.duration_hidden_dim)
        )

        # Prosodic statistics (broadcast per token by the caller).
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.prosody_dim, config.prosody_dim),
            nn.ReLU(),
            nn.LayerNorm(config.prosody_dim)
        )

        total_feature_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                             config.duration_hidden_dim + config.prosody_dim)
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.dropout_rate)
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        """Return a pooled linguistic feature vector of shape (batch, total_feature_dim // 2)."""
        batch_size, seq_len = pos_ids.size()

        # Clamp out-of-vocabulary POS ids instead of crashing the embedding.
        pos_ids_clamped = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
        pos_embeds = self.pos_embedding(pos_ids_clamped)
        pos_features = self.pos_attention(pos_embeds, attention_mask)

        grammar_features = self.grammar_projection(grammar_ids.float())
        duration_features = self.duration_projection(durations.unsqueeze(-1).float())
        prosody_features = self.prosody_projection(prosody_features.float())

        combined_features = torch.cat([
            pos_features, grammar_features, duration_features, prosody_features
        ], dim=-1)

        fused_features = self.feature_fusion(combined_features)

        # Mean-pool over unmasked positions. BUG FIX: the denominator is
        # clamped so an all-padding sequence yields zeros instead of NaN
        # (previously a straight division by the mask sum).
        mask_expanded = attention_mask.unsqueeze(-1).float()
        denom = torch.sum(mask_expanded, dim=1).clamp(min=1e-9)
        pooled_features = torch.sum(fused_features * mask_expanded, dim=1) / denom

        return pooled_features
|
| 159 |
-
|
| 160 |
-
class StableAphasiaClassifier(nn.Module):
    """BERT-based aphasia-type classifier with auxiliary severity/fluency heads.

    Combines pooled BERT features (with positional enhancement) and pooled
    word-level linguistic features, then predicts the aphasia class logits,
    a 4-level severity distribution and a scalar fluency score.
    """

    def __init__(self, config: ModelConfig, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        # Pretrained biomedical BERT backbone.
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.bert_config = self.bert.config

        self.positional_encoder = StablePositionalEncoding(
            d_model=self.bert_config.hidden_size,
            max_len=config.max_length
        )

        self.linguistic_extractor = StableLinguisticFeatureExtractor(config)

        bert_dim = self.bert_config.hidden_size
        # Pooled linguistic feature size produced by the extractor.
        linguistic_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                          config.duration_hidden_dim + config.prosody_dim) // 2

        # Fuse pooled BERT features with pooled linguistic features.
        self.feature_fusion = nn.Sequential(
            nn.Linear(bert_dim + linguistic_dim, bert_dim),
            nn.LayerNorm(bert_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate)
        )

        self.classifier = self._build_classifier(bert_dim, num_labels)

        # Auxiliary heads: 4-level severity distribution and fluency in [0, 1].
        self.severity_head = nn.Sequential(
            nn.Linear(bert_dim, 4),
            nn.Softmax(dim=-1)
        )

        self.fluency_head = nn.Sequential(
            nn.Linear(bert_dim, 1),
            nn.Sigmoid()
        )

    def _build_classifier(self, input_dim: int, num_labels: int):
        """Stack Linear/LayerNorm/Tanh/Dropout blocks, ending in a logits layer."""
        layers = []
        current_dim = input_dim

        for hidden_dim in self.config.classifier_hidden_dims:
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Tanh(),
                nn.Dropout(self.config.dropout_rate)
            ])
            current_dim = hidden_dim

        layers.append(nn.Linear(current_dim, num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        """Run the classifier.

        Returns a dict with "logits", "severity_pred", "fluency_pred" and
        "loss". BUG FIX: ``labels`` used to be accepted but ignored and the
        returned loss was always None; cross-entropy is now computed when
        labels are provided (still None at inference time, so existing
        callers are unaffected).
        """
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs.last_hidden_state

        position_enhanced = self.positional_encoder(sequence_output)
        pooled_output = self._attention_pooling(position_enhanced, attention_mask)

        if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
            if prosody_features is None:
                # Default to zero prosody when only word-level features exist.
                batch_size, seq_len = input_ids.size()
                prosody_features = torch.zeros(
                    batch_size, seq_len, self.config.prosody_dim,
                    device=input_ids.device
                )

            linguistic_features = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations,
                prosody_features, attention_mask
            )
        else:
            # No word-level features supplied: fall back to a zero vector of
            # the expected pooled linguistic dimension.
            linguistic_features = torch.zeros(
                input_ids.size(0),
                (self.config.pos_emb_dim + self.config.grammar_hidden_dim +
                 self.config.duration_hidden_dim + self.config.prosody_dim) // 2,
                device=input_ids.device
            )

        combined_features = torch.cat([pooled_output, linguistic_features], dim=1)
        fused_features = self.feature_fusion(combined_features)

        logits = self.classifier(fused_features)
        severity_pred = self.severity_head(fused_features)
        fluency_pred = self.fluency_head(fused_features)

        # Supervised loss; None when no labels are given (inference).
        loss = F.cross_entropy(logits, labels) if labels is not None else None

        return {
            "logits": logits,
            "severity_pred": severity_pred,
            "fluency_pred": fluency_pred,
            "loss": loss
        }

    def _attention_pooling(self, sequence_output, attention_mask):
        """Pool the sequence using content-derived attention weights.

        Weights come from a softmax over the per-position feature sums; masked
        positions are zeroed and the weights renormalised (with a small eps).
        """
        attention_weights = torch.softmax(
            torch.sum(sequence_output, dim=-1, keepdim=True), dim=1
        )
        attention_weights = attention_weights * attention_mask.unsqueeze(-1).float()
        attention_weights = attention_weights / (torch.sum(attention_weights, dim=1, keepdim=True) + 1e-9)
        pooled = torch.sum(sequence_output * attention_weights, dim=1)
        return pooled
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
class AphasiaInferenceSystem:
    """Aphasia classification inference system.

    Loads a trained ``StableAphasiaClassifier`` (plus tokenizer and label
    mappings) from a model directory and exposes single-sentence prediction,
    batch prediction over a JSON file, and CSV/JSON report generation.
    """

    def __init__(self, model_dir: str):
        """
        Initialise the inference system.

        Args:
            model_dir: path to the directory holding the trained model.
        """
        # BUG FIX: the argument used to be ignored in favour of a hard-coded
        # '/workspace/SH001/adaptive_aphasia_model' path; honour it now.
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Human-readable descriptions for each aphasia type label.
        self.aphasia_descriptions = {
            "BROCA": {
                "name": "Broca's Aphasia (Non-fluent)",
                "description": "Characterized by limited speech output, difficulty with grammar and sentence formation, but relatively preserved comprehension. Speech is typically effortful and halting.",
                "features": ["Non-fluent speech", "Preserved comprehension", "Grammar difficulties", "Word-finding problems"]
            },
            "TRANSMOTOR": {
                "name": "Trans-cortical Motor Aphasia",
                "description": "Similar to Broca's aphasia but with preserved repetition abilities. Speech is non-fluent with good comprehension.",
                "features": ["Non-fluent speech", "Good repetition", "Preserved comprehension", "Grammar difficulties"]
            },
            "NOTAPHASICBYWAB": {
                "name": "Not Aphasic by WAB",
                "description": "Individuals who do not meet the criteria for aphasia according to the Western Aphasia Battery assessment.",
                "features": ["Normal language function", "No significant language impairment", "Good comprehension", "Fluent speech"]
            },
            "CONDUCTION": {
                "name": "Conduction Aphasia",
                "description": "Characterized by fluent speech with good comprehension but severely impaired repetition. Often involves phonemic paraphasias.",
                "features": ["Fluent speech", "Good comprehension", "Poor repetition", "Phonemic errors"]
            },
            "WERNICKE": {
                "name": "Wernicke's Aphasia (Fluent)",
                "description": "Fluent but often meaningless speech with poor comprehension. Speech may contain neologisms and jargon.",
                "features": ["Fluent speech", "Poor comprehension", "Jargon speech", "Neologisms"]
            },
            "ANOMIC": {
                "name": "Anomic Aphasia",
                "description": "Primarily characterized by word-finding difficulties with otherwise relatively preserved language abilities.",
                "features": ["Word-finding difficulties", "Good comprehension", "Fluent speech", "Circumlocution"]
            },
            "GLOBAL": {
                "name": "Global Aphasia",
                "description": "Severe impairment in all language modalities - comprehension, production, repetition, and naming.",
                "features": ["Severe comprehension deficit", "Non-fluent speech", "Poor repetition", "Severe naming difficulties"]
            },
            "ISOLATION": {
                "name": "Isolation Syndrome",
                "description": "Rare condition with preserved repetition but severely impaired comprehension and spontaneous speech.",
                "features": ["Good repetition", "Poor comprehension", "Limited spontaneous speech", "Echolalia"]
            },
            "TRANSSENSORY": {
                "name": "Trans-cortical Sensory Aphasia",
                "description": "Fluent speech with good repetition but impaired comprehension, similar to Wernicke's but with preserved repetition.",
                "features": ["Fluent speech", "Good repetition", "Poor comprehension", "Semantic errors"]
            }
        }

        # Load model configuration / label mappings, then the model itself.
        self.load_configuration()
        self.load_model()

        print(f"推理系統初始化完成,使用設備: {self.device}")

    def load_configuration(self):
        """Load label mappings and model name from config.json (with fallbacks)."""
        config_path = os.path.join(self.model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config_data = json.load(f)

            self.aphasia_types_mapping = config_data.get("aphasia_types_mapping", {
                "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
                "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
                "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
            })
            self.num_labels = config_data.get("num_labels", 9)
            self.model_name = config_data.get("model_name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
        else:
            # Default configuration when no config.json is present.
            self.aphasia_types_mapping = {
                "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
                "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
                "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
            }
            self.num_labels = 9
            self.model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

        # Reverse mapping: label id -> aphasia type name.
        self.id_to_aphasia_type = {v: k for k, v in self.aphasia_types_mapping.items()}

    def load_model(self):
        """Load the tokenizer, rebuild the model architecture and load its weights."""
        # Tokenizer (with any extra tokens recorded at training time).
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        added_tokens_path = os.path.join(self.model_dir, "added_tokens.json")
        if os.path.exists(added_tokens_path):
            with open(added_tokens_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            # A dict maps token -> id; the keys are the tokens to add.
            if isinstance(data, dict):
                tokens = list(data.keys())
            else:
                tokens = data  # already a plain list of tokens
            num_added = self.tokenizer.add_tokens(tokens)
            print(f"新增到 tokenizer 的 token 數量: {num_added}")

        # Rebuild the model configuration and architecture.
        self.config = ModelConfig()
        self.config.model_name = self.model_name

        self.model = StableAphasiaClassifier(self.config, self.num_labels)
        # Resize embeddings BEFORE loading weights so the state dict fits.
        self.model.bert.resize_token_embeddings(len(self.tokenizer))

        # Load the trained weights.
        # BUG FIX: load_state_dict was previously called twice, and
        # resize_token_embeddings was redundantly repeated after loading
        # (which could disturb the just-loaded embedding weights).
        model_path = os.path.join(self.model_dir, "pytorch_model.bin")
        if os.path.exists(model_path):
            state_dict = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            print("模型權重載入成功")
        else:
            raise FileNotFoundError(f"模型權重文件不存在: {model_path}")

        # Move to the target device and switch to evaluation mode.
        self.model.to(self.device)
        self.model.eval()

    def preprocess_sentence(self, sentence_data: dict) -> Optional[dict]:
        """Flatten one sentence record into model-ready tensors.

        Returns None when the record contains no tokens.
        """
        all_tokens, all_pos, all_grammar, all_durations = [], [], [], []

        # Walk the dialogues, inserting a separator token between them.
        for dialogue_idx, dialogue in enumerate(sentence_data.get("dialogues", [])):
            if dialogue_idx > 0:
                all_tokens.append("[DIALOGUE]")
                all_pos.append(0)
                all_grammar.append([0, 0, 0])
                all_durations.append(0.0)

            # Collect the participant's utterances and aligned features.
            for par in dialogue.get("PAR", []):
                if "tokens" in par and par["tokens"]:
                    tokens = par["tokens"]
                    pos_ids = par.get("word_pos_ids", [0] * len(tokens))
                    grammar_ids = par.get("word_grammar_ids", [[0, 0, 0]] * len(tokens))
                    durations = par.get("word_durations", [0.0] * len(tokens))

                    all_tokens.extend(tokens)
                    all_pos.extend(pos_ids)
                    all_grammar.extend(grammar_ids)
                    all_durations.extend(durations)

        if not all_tokens:
            return None

        # Tokenize the flattened text to a fixed-length input.
        text = " ".join(all_tokens)
        encoded = self.tokenizer(
            text,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Align word-level features with the BERT subword positions.
        aligned_pos, aligned_grammar, aligned_durations = self._align_features(
            all_tokens, all_pos, all_grammar, all_durations, encoded
        )

        # Utterance-level prosody, broadcast to every sequence position.
        prosody_features = self._extract_prosodic_features(all_durations, all_tokens)
        prosody_tensor = torch.tensor(prosody_features).unsqueeze(0).repeat(
            self.config.max_length, 1
        )

        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
            "word_grammar_ids": torch.tensor(aligned_grammar, dtype=torch.long),
            "word_durations": torch.tensor(aligned_durations, dtype=torch.float),
            "prosody_features": prosody_tensor.float(),
            "sentence_id": sentence_data.get("sentence_id", "unknown"),
            "original_tokens": all_tokens,
            "text": text
        }

    def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
        """Align word-level features with BERT subword positions ([CLS]/[SEP] get zeros)."""
        subtoken_to_token = []

        # Map every subword position back to its source word index.
        for token_idx, token in enumerate(tokens):
            subtokens = self.tokenizer.tokenize(token)
            subtoken_to_token.extend([token_idx] * len(subtokens))

        aligned_pos = [0]                # [CLS]
        aligned_grammar = [[0, 0, 0]]    # [CLS]
        aligned_durations = [0.0]        # [CLS]

        for subtoken_idx in range(1, self.config.max_length - 1):
            if subtoken_idx - 1 < len(subtoken_to_token):
                original_idx = subtoken_to_token[subtoken_idx - 1]
                aligned_pos.append(pos_ids[original_idx] if original_idx < len(pos_ids) else 0)
                aligned_grammar.append(grammar_ids[original_idx] if original_idx < len(grammar_ids) else [0, 0, 0])

                # Durations may be scalars or [start, end] pairs; normalise to seconds.
                raw_duration = durations[original_idx] if original_idx < len(durations) else 0.0
                if isinstance(raw_duration, list) and len(raw_duration) >= 2:
                    try:
                        duration_val = float(raw_duration[1]) - float(raw_duration[0])
                    except (ValueError, TypeError):
                        duration_val = 0.0
                elif isinstance(raw_duration, (int, float)):
                    duration_val = float(raw_duration)
                else:
                    duration_val = 0.0

                aligned_durations.append(duration_val)
            else:
                # Padding positions beyond the real tokens.
                aligned_pos.append(0)
                aligned_grammar.append([0, 0, 0])
                aligned_durations.append(0.0)

        aligned_pos.append(0)            # [SEP]
        aligned_grammar.append([0, 0, 0])  # [SEP]
        aligned_durations.append(0.0)    # [SEP]

        return aligned_pos, aligned_grammar, aligned_durations

    def _extract_prosodic_features(self, durations, tokens):
        """Derive prosody_dim summary statistics from the word durations.

        Accepts durations as scalars or [start, end] pairs; returns a
        zero-padded list of length ``config.prosody_dim``.
        """
        if not durations:
            return [0.0] * self.config.prosody_dim

        # Normalise every entry to a scalar duration; skip malformed pairs.
        processed_durations = []
        for d in durations:
            if isinstance(d, list) and len(d) >= 2:
                try:
                    processed_durations.append(float(d[1]) - float(d[0]))
                except (ValueError, TypeError):
                    continue
            elif isinstance(d, (int, float)):
                processed_durations.append(float(d))

        if not processed_durations:
            return [0.0] * self.config.prosody_dim

        # Basic statistics: mean, std, median, count of long pauses (> 1.5x mean).
        features = [
            np.mean(processed_durations),
            np.std(processed_durations),
            np.median(processed_durations),
            len([d for d in processed_durations if d > np.mean(processed_durations) * 1.5])
        ]

        # Zero-pad to the required dimensionality.
        while len(features) < self.config.prosody_dim:
            features.append(0.0)

        return features[:self.config.prosody_dim]

    def predict_single(self, sentence_data: dict) -> dict:
        """Predict the aphasia type for one sentence record.

        Returns a result dict with the prediction, per-class probabilities and
        the auxiliary severity/fluency estimates, or an "error" entry when the
        record cannot be preprocessed.
        """
        processed_data = self.preprocess_sentence(sentence_data)
        if processed_data is None:
            return {
                "error": "無法處理輸入數據",
                "sentence_id": sentence_data.get("sentence_id", "unknown")
            }

        # Assemble the batched (batch=1) model inputs on the target device.
        input_data = {
            "input_ids": processed_data["input_ids"].unsqueeze(0).to(self.device),
            "attention_mask": processed_data["attention_mask"].unsqueeze(0).to(self.device),
            "word_pos_ids": processed_data["word_pos_ids"].unsqueeze(0).to(self.device),
            "word_grammar_ids": processed_data["word_grammar_ids"].unsqueeze(0).to(self.device),
            "word_durations": processed_data["word_durations"].unsqueeze(0).to(self.device),
            "prosody_features": processed_data["prosody_features"].unsqueeze(0).to(self.device)
        }

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = self.model(**input_data)

        logits = outputs["logits"]
        probabilities = F.softmax(logits, dim=1).cpu().numpy()[0]
        # Cast to a plain int for clean dict lookups and JSON serialisation.
        predicted_class_id = int(np.argmax(probabilities))

        severity_pred = outputs["severity_pred"].cpu().numpy()[0]
        fluency_pred = outputs["fluency_pred"].cpu().numpy()[0][0]

        predicted_type = self.id_to_aphasia_type[predicted_class_id]
        confidence = float(probabilities[predicted_class_id])

        # Per-class probability distribution, sorted by probability descending.
        probability_distribution = {}
        for aphasia_type, type_id in self.aphasia_types_mapping.items():
            probability_distribution[aphasia_type] = {
                "probability": float(probabilities[type_id]),
                "percentage": f"{probabilities[type_id]*100:.2f}%"
            }

        sorted_probabilities = sorted(
            probability_distribution.items(),
            key=lambda x: x[1]["probability"],
            reverse=True
        )

        result = {
            "sentence_id": processed_data["sentence_id"],
            "input_text": processed_data["text"],
            "original_tokens": processed_data["original_tokens"],
            "prediction": {
                "predicted_class": predicted_type,
                "confidence": confidence,
                "confidence_percentage": f"{confidence*100:.2f}%"
            },
            "class_description": self.aphasia_descriptions.get(predicted_type, {
                "name": predicted_type,
                "description": "Description not available",
                "features": []
            }),
            "probability_distribution": dict(sorted_probabilities),
            "additional_predictions": {
                "severity_distribution": {
                    "level_0": float(severity_pred[0]),
                    "level_1": float(severity_pred[1]),
                    "level_2": float(severity_pred[2]),
                    "level_3": float(severity_pred[3])
                },
                "predicted_severity_level": int(np.argmax(severity_pred)),
                "fluency_score": float(fluency_pred),
                "fluency_rating": "High" if fluency_pred > 0.7 else "Medium" if fluency_pred > 0.4 else "Low"
            }
        }

        return result

    def predict_batch(self, input_file: str, output_file: Optional[str] = None) -> dict:
        """Predict every sentence in a JSON file.

        BUG FIX: the return annotation previously claimed ``List[dict]`` but
        the method returns a dict with "summary", "total_sentences" and
        "predictions" keys (as main() relies on).
        """
        with open(input_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        sentences = data.get("sentences", [])
        results = []

        print(f"開始處理 {len(sentences)} 個句子...")

        for i, sentence in enumerate(sentences):
            print(f"處理第 {i+1}/{len(sentences)} 個句子...")
            result = self.predict_single(sentence)
            results.append(result)

        # Aggregate statistics over all predictions.
        summary = self._generate_summary(results)

        final_output = {
            "summary": summary,
            "total_sentences": len(results),
            "predictions": results
        }

        # Persist the combined output when a path is given.
        if output_file:
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(final_output, f, ensure_ascii=False, indent=2)
            print(f"結果已保存到: {output_file}")

        return final_output

    def _generate_summary(self, results: List[dict]) -> dict:
        """Aggregate class counts, confidence, fluency and severity statistics."""
        if not results:
            return {}

        class_counts = defaultdict(int)
        confidence_scores = []
        fluency_scores = []
        severity_levels = defaultdict(int)

        # Only successful predictions contribute to the statistics.
        for result in results:
            if "error" not in result:
                predicted_class = result["prediction"]["predicted_class"]
                confidence = result["prediction"]["confidence"]
                fluency = result["additional_predictions"]["fluency_score"]
                severity = result["additional_predictions"]["predicted_severity_level"]

                class_counts[predicted_class] += 1
                confidence_scores.append(confidence)
                fluency_scores.append(fluency)
                severity_levels[severity] += 1

        avg_confidence = np.mean(confidence_scores) if confidence_scores else 0
        avg_fluency = np.mean(fluency_scores) if fluency_scores else 0

        summary = {
            "classification_distribution": dict(class_counts),
            "classification_percentages": {
                k: f"{v/len(results)*100:.1f}%"
                for k, v in class_counts.items()
            },
            "average_confidence": f"{avg_confidence:.3f}",
            "average_fluency_score": f"{avg_fluency:.3f}",
            "severity_distribution": dict(severity_levels),
            "confidence_statistics": {
                "mean": f"{np.mean(confidence_scores):.3f}",
                "std": f"{np.std(confidence_scores):.3f}",
                "min": f"{np.min(confidence_scores):.3f}",
                "max": f"{np.max(confidence_scores):.3f}"
            } if confidence_scores else {},
            "most_common_prediction": max(class_counts.items(), key=lambda x: x[1])[0] if class_counts else "None"
        }

        return summary

    def generate_detailed_report(self, results: List[dict], output_dir: str = "./inference_results"):
        """Write a per-sentence CSV and a JSON summary; return the DataFrame (or None)."""
        os.makedirs(output_dir, exist_ok=True)

        # One CSV row per successful prediction.
        report_data = []
        for result in results:
            if "error" not in result:
                row = {
                    "sentence_id": result["sentence_id"],
                    "predicted_class": result["prediction"]["predicted_class"],
                    "confidence": result["prediction"]["confidence"],
                    "class_name": result["class_description"]["name"],
                    "severity_level": result["additional_predictions"]["predicted_severity_level"],
                    "fluency_score": result["additional_predictions"]["fluency_score"],
                    "fluency_rating": result["additional_predictions"]["fluency_rating"],
                    "input_text": result["input_text"]
                }

                # Append one probability column per aphasia type.
                for aphasia_type in self.aphasia_types_mapping.keys():
                    row[f"prob_{aphasia_type}"] = result["probability_distribution"][aphasia_type]["probability"]

                report_data.append(row)

        if report_data:
            df = pd.DataFrame(report_data)
            df.to_csv(os.path.join(output_dir, "detailed_predictions.csv"), index=False, encoding='utf-8')

            # Companion JSON with aggregate statistics.
            summary_stats = {
                "total_predictions": len(report_data),
                "class_distribution": df["predicted_class"].value_counts().to_dict(),
                "average_confidence": df["confidence"].mean(),
                "confidence_std": df["confidence"].std(),
                "average_fluency": df["fluency_score"].mean(),
                "fluency_std": df["fluency_score"].std(),
                "severity_distribution": df["severity_level"].value_counts().to_dict()
            }

            with open(os.path.join(output_dir, "summary_statistics.json"), "w", encoding="utf-8") as f:
                json.dump(summary_stats, f, ensure_ascii=False, indent=2)

            print(f"詳細報告已生成並保存到: {output_dir}")
            return df

        return None
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
def main():
|
| 753 |
-
"""主程式 - 命令行介面"""
|
| 754 |
-
import argparse
|
| 755 |
-
|
| 756 |
-
parser = argparse.ArgumentParser(description="失語症分類推理系統")
|
| 757 |
-
parser.add_argument("--model_dir", type=str, default = '/workspace/SH001/adaptive_aphasia_model',
|
| 758 |
-
help="訓練好的模型目錄路徑")
|
| 759 |
-
parser.add_argument("--input_file", type=str, default = '/workspace/SH001/website/sample.input.json',
|
| 760 |
-
help="輸入JSON文件路徑")
|
| 761 |
-
parser.add_argument("--output_file", type=str, default="./aphasia_predictions.json",
|
| 762 |
-
help="輸出JSON文件路徑")
|
| 763 |
-
parser.add_argument("--report_dir", type=str, default="./inference_results",
|
| 764 |
-
help="詳細報告輸出目錄")
|
| 765 |
-
parser.add_argument("--generate_report", action="store_true",
|
| 766 |
-
help="是否生成詳細的CSV報告")
|
| 767 |
-
|
| 768 |
-
args = parser.parse_args()
|
| 769 |
-
|
| 770 |
-
try:
|
| 771 |
-
# 初始化推理系統
|
| 772 |
-
print("正在初始化推理系統...")
|
| 773 |
-
inference_system = AphasiaInferenceSystem(args.model_dir)
|
| 774 |
-
|
| 775 |
-
# 執行批次預測
|
| 776 |
-
print("開始執行批次預測...")
|
| 777 |
-
results = inference_system.predict_batch(args.input_file, args.output_file)
|
| 778 |
-
|
| 779 |
-
# 生成詳細報告
|
| 780 |
-
if args.generate_report:
|
| 781 |
-
print("生成詳細報告...")
|
| 782 |
-
inference_system.generate_detailed_report(results["predictions"], args.report_dir)
|
| 783 |
-
|
| 784 |
-
# 顯示摘要
|
| 785 |
-
print("\n=== 預測摘要 ===")
|
| 786 |
-
summary = results["summary"]
|
| 787 |
-
print(f"總句子數: {results['total_sentences']}")
|
| 788 |
-
print(f"平均信心度: {summary.get('average_confidence', 'N/A')}")
|
| 789 |
-
print(f"平均流利度: {summary.get('average_fluency_score', 'N/A')}")
|
| 790 |
-
print(f"最常見預測: {summary.get('most_common_prediction', 'N/A')}")
|
| 791 |
-
|
| 792 |
-
print("\n類別分佈:")
|
| 793 |
-
for class_name, count in summary.get("classification_distribution", {}).items():
|
| 794 |
-
percentage = summary.get("classification_percentages", {}).get(class_name, "0%")
|
| 795 |
-
print(f" {class_name}: {count} ({percentage})")
|
| 796 |
-
|
| 797 |
-
print(f"\n結果已保存到: {args.output_file}")
|
| 798 |
-
|
| 799 |
-
except Exception as e:
|
| 800 |
-
print(f"錯誤: {str(e)}")
|
| 801 |
-
import traceback
|
| 802 |
-
traceback.print_exc()
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
# 使用範例
|
| 806 |
-
def example_usage():
|
| 807 |
-
"""使用範例"""
|
| 808 |
-
|
| 809 |
-
# 1. 基本使用
|
| 810 |
-
print("=== 失語症分類推理系統使用範例 ===\n")
|
| 811 |
-
|
| 812 |
-
# 範例輸入數據
|
| 813 |
-
sample_input = {
|
| 814 |
-
"sentences": [
|
| 815 |
-
{
|
| 816 |
-
"sentence_id": "S1",
|
| 817 |
-
"aphasia_type": "BROCA", # 這在推理時會被忽略
|
| 818 |
-
"dialogues": [
|
| 819 |
-
{
|
| 820 |
-
"INV": [
|
| 821 |
-
{
|
| 822 |
-
"tokens": ["how", "are", "you", "feeling"],
|
| 823 |
-
"word_pos_ids": [9, 10, 5, 6],
|
| 824 |
-
"word_grammar_ids": [[1, 4, 11], [2, 4, 2], [3, 4, 1], [4, 0, 3]],
|
| 825 |
-
"word_durations": [["how", 300], ["are", 200], ["you", 150], ["feeling", 500]]
|
| 826 |
-
}
|
| 827 |
-
],
|
| 828 |
-
"PAR": [
|
| 829 |
-
{
|
| 830 |
-
"tokens": ["I", "feel", "good"],
|
| 831 |
-
"word_pos_ids": [1, 6, 8],
|
| 832 |
-
"word_grammar_ids": [[1, 2, 1], [2, 3, 2], [3, 4, 8]],
|
| 833 |
-
"word_durations": [["I", 200], ["feel", 400], ["good", 600]]
|
| 834 |
-
}
|
| 835 |
-
]
|
| 836 |
-
}
|
| 837 |
-
]
|
| 838 |
-
}
|
| 839 |
-
]
|
| 840 |
-
}
|
| 841 |
-
|
| 842 |
-
# 保存範例輸入
|
| 843 |
-
with open("sample_input.json", "w", encoding="utf-8") as f:
|
| 844 |
-
json.dump(sample_input, f, ensure_ascii=False, indent=2)
|
| 845 |
-
|
| 846 |
-
print("範例輸入文件已創建: sample_input.json")
|
| 847 |
-
|
| 848 |
-
# 顯示使用說明
|
| 849 |
-
usage_instructions = """
|
| 850 |
-
使用方法:
|
| 851 |
-
|
| 852 |
-
1. 命令行使用:
|
| 853 |
-
python aphasia_inference.py \\
|
| 854 |
-
--model_dir ./adaptive_aphasia_model \\
|
| 855 |
-
--input_file sample_input.json \\
|
| 856 |
-
--output_file predictions.json \\
|
| 857 |
-
--generate_report \\
|
| 858 |
-
--report_dir ./results
|
| 859 |
-
|
| 860 |
-
2. Python代碼使用:
|
| 861 |
-
from aphasia_inference import AphasiaInferenceSystem
|
| 862 |
-
|
| 863 |
-
# 初始化系統
|
| 864 |
-
system = AphasiaInferenceSystem("./adaptive_aphasia_model")
|
| 865 |
-
|
| 866 |
-
# 單個預測
|
| 867 |
-
with open("sample_input.json", "r") as f:
|
| 868 |
-
data = json.load(f)
|
| 869 |
-
result = system.predict_single(data["sentences"][0])
|
| 870 |
-
|
| 871 |
-
# 批次預測
|
| 872 |
-
results = system.predict_batch("sample_input.json", "output.json")
|
| 873 |
-
|
| 874 |
-
3. 輸出格式:
|
| 875 |
-
- JSON格式包含詳細的預測結果和機率分佈
|
| 876 |
-
- CSV格式包含表格化的預測數據
|
| 877 |
-
- 統計摘要包含整體分析結果
|
| 878 |
-
|
| 879 |
-
4. 支援的失語症類型:
|
| 880 |
-
- BROCA: 布若卡失語症
|
| 881 |
-
- WERNICKE: 韋尼克失語症
|
| 882 |
-
- ANOMIC: 命名性失語症
|
| 883 |
-
- CONDUCTION: 傳導性失語症
|
| 884 |
-
- GLOBAL: 全面性失語症
|
| 885 |
-
- 以及其他類型...
|
| 886 |
-
"""
|
| 887 |
-
|
| 888 |
-
print(usage_instructions)
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
if __name__ == "__main__":
|
| 892 |
-
# 如果作為腳本執行,運行主程式
|
| 893 |
-
main()
|
| 894 |
-
|
| 895 |
-
# 如果想看使用範例,取消下面這行的註釋
|
| 896 |
-
# example_usage()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output.py
ADDED
|
@@ -0,0 +1,642 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Aphasia classification inference (cleaned).
|
| 4 |
+
- Respects model_dir argument
|
| 5 |
+
- Correctly parses durations like ["word", 300] and [start, end]
|
| 6 |
+
- Removes duplicate load_state_dict
|
| 7 |
+
- Adds predict_from_chajson(json_path, ...) helper
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import math
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Dict, List, Optional, Tuple
|
| 15 |
+
from collections import defaultdict
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import pandas as pd
|
| 22 |
+
from transformers import AutoTokenizer, AutoModel
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# =========================
|
| 26 |
+
# Model definition (unchanged shape)
|
| 27 |
+
# =========================
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class ModelConfig:
|
| 31 |
+
model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
|
| 32 |
+
max_length: int = 512
|
| 33 |
+
hidden_size: int = 768
|
| 34 |
+
pos_vocab_size: int = 150
|
| 35 |
+
pos_emb_dim: int = 64
|
| 36 |
+
grammar_dim: int = 3
|
| 37 |
+
grammar_hidden_dim: int = 64
|
| 38 |
+
duration_hidden_dim: int = 128
|
| 39 |
+
prosody_dim: int = 32
|
| 40 |
+
num_attention_heads: int = 8
|
| 41 |
+
attention_dropout: float = 0.3
|
| 42 |
+
classifier_hidden_dims: List[int] = None
|
| 43 |
+
dropout_rate: float = 0.3
|
| 44 |
+
def __post_init__(self):
|
| 45 |
+
if self.classifier_hidden_dims is None:
|
| 46 |
+
self.classifier_hidden_dims = [512, 256]
|
| 47 |
+
|
| 48 |
+
class StablePositionalEncoding(nn.Module):
|
| 49 |
+
def __init__(self, d_model: int, max_len: int = 5000):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.d_model = d_model
|
| 52 |
+
pe = torch.zeros(max_len, d_model)
|
| 53 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 54 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
| 55 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 56 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 57 |
+
self.register_buffer('pe', pe.unsqueeze(0))
|
| 58 |
+
self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)
|
| 59 |
+
def forward(self, x):
|
| 60 |
+
seq_len = x.size(1)
|
| 61 |
+
sinusoidal = self.pe[:, :seq_len, :].to(x.device)
|
| 62 |
+
learnable = self.learnable_pe[:seq_len, :].unsqueeze(0).expand(x.size(0), -1, -1)
|
| 63 |
+
return x + 0.1 * (sinusoidal + learnable)
|
| 64 |
+
|
| 65 |
+
class StableMultiHeadAttention(nn.Module):
|
| 66 |
+
def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.num_heads = num_heads
|
| 69 |
+
self.feature_dim = feature_dim
|
| 70 |
+
self.head_dim = feature_dim // num_heads
|
| 71 |
+
assert feature_dim % num_heads == 0
|
| 72 |
+
self.query = nn.Linear(feature_dim, feature_dim)
|
| 73 |
+
self.key = nn.Linear(feature_dim, feature_dim)
|
| 74 |
+
self.value = nn.Linear(feature_dim, feature_dim)
|
| 75 |
+
self.dropout = nn.Dropout(dropout)
|
| 76 |
+
self.output_proj = nn.Linear(feature_dim, feature_dim)
|
| 77 |
+
self.layer_norm = nn.LayerNorm(feature_dim)
|
| 78 |
+
def forward(self, x, mask=None):
|
| 79 |
+
b, t, _ = x.size()
|
| 80 |
+
Q = self.query(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
|
| 81 |
+
K = self.key(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
|
| 82 |
+
V = self.value(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
|
| 83 |
+
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
| 84 |
+
if mask is not None:
|
| 85 |
+
if mask.dim() == 2:
|
| 86 |
+
mask = mask.unsqueeze(1).unsqueeze(1)
|
| 87 |
+
scores.masked_fill_(mask == 0, -1e9)
|
| 88 |
+
attn = F.softmax(scores, dim=-1)
|
| 89 |
+
attn = self.dropout(attn)
|
| 90 |
+
ctx = torch.matmul(attn, V)
|
| 91 |
+
ctx = ctx.transpose(1, 2).contiguous().view(b, t, self.feature_dim)
|
| 92 |
+
out = self.output_proj(ctx)
|
| 93 |
+
return self.layer_norm(out + x)
|
| 94 |
+
|
| 95 |
+
class StableLinguisticFeatureExtractor(nn.Module):
|
| 96 |
+
def __init__(self, config: ModelConfig):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.config = config
|
| 99 |
+
self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
|
| 100 |
+
self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)
|
| 101 |
+
self.grammar_projection = nn.Sequential(
|
| 102 |
+
nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
|
| 103 |
+
nn.Tanh(),
|
| 104 |
+
nn.LayerNorm(config.grammar_hidden_dim),
|
| 105 |
+
nn.Dropout(config.dropout_rate * 0.3)
|
| 106 |
+
)
|
| 107 |
+
self.duration_projection = nn.Sequential(
|
| 108 |
+
nn.Linear(1, config.duration_hidden_dim),
|
| 109 |
+
nn.Tanh(),
|
| 110 |
+
nn.LayerNorm(config.duration_hidden_dim)
|
| 111 |
+
)
|
| 112 |
+
self.prosody_projection = nn.Sequential(
|
| 113 |
+
nn.Linear(config.prosody_dim, config.prosody_dim),
|
| 114 |
+
nn.ReLU(),
|
| 115 |
+
nn.LayerNorm(config.prosody_dim)
|
| 116 |
+
)
|
| 117 |
+
total_feature_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
|
| 118 |
+
config.duration_hidden_dim + config.prosody_dim)
|
| 119 |
+
self.feature_fusion = nn.Sequential(
|
| 120 |
+
nn.Linear(total_feature_dim, total_feature_dim // 2),
|
| 121 |
+
nn.Tanh(),
|
| 122 |
+
nn.LayerNorm(total_feature_dim // 2),
|
| 123 |
+
nn.Dropout(config.dropout_rate)
|
| 124 |
+
)
|
| 125 |
+
def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
|
| 126 |
+
b, t = pos_ids.size()
|
| 127 |
+
pos_ids = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
|
| 128 |
+
pos_emb = self.pos_embedding(pos_ids)
|
| 129 |
+
pos_feat = self.pos_attention(pos_emb, attention_mask)
|
| 130 |
+
gra_feat = self.grammar_projection(grammar_ids.float())
|
| 131 |
+
dur_feat = self.duration_projection(durations.unsqueeze(-1).float())
|
| 132 |
+
pro_feat = self.prosody_projection(prosody_features.float())
|
| 133 |
+
combined = torch.cat([pos_feat, gra_feat, dur_feat, pro_feat], dim=-1)
|
| 134 |
+
fused = self.feature_fusion(combined)
|
| 135 |
+
mask_exp = attention_mask.unsqueeze(-1).float()
|
| 136 |
+
pooled = torch.sum(fused * mask_exp, dim=1) / torch.sum(mask_exp, dim=1)
|
| 137 |
+
return pooled
|
| 138 |
+
|
| 139 |
+
class StableAphasiaClassifier(nn.Module):
|
| 140 |
+
def __init__(self, config: ModelConfig, num_labels: int):
|
| 141 |
+
super().__init__()
|
| 142 |
+
self.config = config
|
| 143 |
+
self.num_labels = num_labels
|
| 144 |
+
self.bert = AutoModel.from_pretrained(config.model_name)
|
| 145 |
+
self.bert_config = self.bert.config
|
| 146 |
+
self.positional_encoder = StablePositionalEncoding(d_model=self.bert_config.hidden_size,
|
| 147 |
+
max_len=config.max_length)
|
| 148 |
+
self.linguistic_extractor = StableLinguisticFeatureExtractor(config)
|
| 149 |
+
bert_dim = self.bert_config.hidden_size
|
| 150 |
+
lingu_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
|
| 151 |
+
config.duration_hidden_dim + config.prosody_dim) // 2
|
| 152 |
+
self.feature_fusion = nn.Sequential(
|
| 153 |
+
nn.Linear(bert_dim + lingu_dim, bert_dim),
|
| 154 |
+
nn.LayerNorm(bert_dim),
|
| 155 |
+
nn.Tanh(),
|
| 156 |
+
nn.Dropout(config.dropout_rate)
|
| 157 |
+
)
|
| 158 |
+
self.classifier = self._build_classifier(bert_dim, num_labels)
|
| 159 |
+
self.severity_head = nn.Sequential(nn.Linear(bert_dim, 4), nn.Softmax(dim=-1))
|
| 160 |
+
self.fluency_head = nn.Sequential(nn.Linear(bert_dim, 1), nn.Sigmoid())
|
| 161 |
+
def _build_classifier(self, input_dim: int, num_labels: int):
|
| 162 |
+
layers, cur = [], input_dim
|
| 163 |
+
for h in self.config.classifier_hidden_dims:
|
| 164 |
+
layers += [nn.Linear(cur, h), nn.LayerNorm(h), nn.Tanh(), nn.Dropout(self.config.dropout_rate)]
|
| 165 |
+
cur = h
|
| 166 |
+
layers.append(nn.Linear(cur, num_labels))
|
| 167 |
+
return nn.Sequential(*layers)
|
| 168 |
+
def _attention_pooling(self, seq_out, attn_mask):
|
| 169 |
+
attn_w = torch.softmax(torch.sum(seq_out, dim=-1, keepdim=True), dim=1)
|
| 170 |
+
attn_w = attn_w * attn_mask.unsqueeze(-1).float()
|
| 171 |
+
attn_w = attn_w / (torch.sum(attn_w, dim=1, keepdim=True) + 1e-9)
|
| 172 |
+
return torch.sum(seq_out * attn_w, dim=1)
|
| 173 |
+
def forward(self, input_ids, attention_mask, labels=None,
|
| 174 |
+
word_pos_ids=None, word_grammar_ids=None, word_durations=None,
|
| 175 |
+
prosody_features=None, **kwargs):
|
| 176 |
+
bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
| 177 |
+
seq_out = bert_out.last_hidden_state
|
| 178 |
+
pos_enh = self.positional_encoder(seq_out)
|
| 179 |
+
pooled = self._attention_pooling(pos_enh, attention_mask)
|
| 180 |
+
if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
|
| 181 |
+
if prosody_features is None:
|
| 182 |
+
b, t = input_ids.size()
|
| 183 |
+
prosody_features = torch.zeros(b, t, self.config.prosody_dim, device=input_ids.device)
|
| 184 |
+
ling = self.linguistic_extractor(word_pos_ids, word_grammar_ids, word_durations,
|
| 185 |
+
prosody_features, attention_mask)
|
| 186 |
+
else:
|
| 187 |
+
ling = torch.zeros(input_ids.size(0),
|
| 188 |
+
(self.config.pos_emb_dim + self.config.grammar_hidden_dim +
|
| 189 |
+
self.config.duration_hidden_dim + self.config.prosody_dim) // 2,
|
| 190 |
+
device=input_ids.device)
|
| 191 |
+
fused = self.feature_fusion(torch.cat([pooled, ling], dim=1))
|
| 192 |
+
logits = self.classifier(fused)
|
| 193 |
+
severity_pred = self.severity_head(fused)
|
| 194 |
+
fluency_pred = self.fluency_head(fused)
|
| 195 |
+
return {"logits": logits, "severity_pred": severity_pred, "fluency_pred": fluency_pred, "loss": None}
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# =========================
|
| 199 |
+
# Inference system (fixed wiring)
|
| 200 |
+
# =========================
|
| 201 |
+
|
| 202 |
+
class AphasiaInferenceSystem:
|
| 203 |
+
"""失語症分類推理系統"""
|
| 204 |
+
|
| 205 |
+
def __init__(self, model_dir: str):
|
| 206 |
+
self.model_dir = model_dir # <— honor the argument
|
| 207 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 208 |
+
|
| 209 |
+
# Descriptions (unchanged)
|
| 210 |
+
self.aphasia_descriptions = {
|
| 211 |
+
"BROCA": {"name": "Broca's Aphasia (Non-fluent)", "description":
|
| 212 |
+
"Characterized by limited speech output, difficulty with grammar and sentence formation, but relatively preserved comprehension. Speech is typically effortful and halting.",
|
| 213 |
+
"features": ["Non-fluent speech", "Preserved comprehension", "Grammar difficulties", "Word-finding problems"]},
|
| 214 |
+
"TRANSMOTOR": {"name": "Trans-cortical Motor Aphasia", "description":
|
| 215 |
+
"Similar to Broca's aphasia but with preserved repetition abilities. Speech is non-fluent with good comprehension.",
|
| 216 |
+
"features": ["Non-fluent speech", "Good repetition", "Preserved comprehension", "Grammar difficulties"]},
|
| 217 |
+
"NOTAPHASICBYWAB": {"name": "Not Aphasic by WAB", "description":
|
| 218 |
+
"Individuals who do not meet the criteria for aphasia according to the Western Aphasia Battery assessment.",
|
| 219 |
+
"features": ["Normal language function", "No significant language impairment", "Good comprehension", "Fluent speech"]},
|
| 220 |
+
"CONDUCTION": {"name": "Conduction Aphasia", "description":
|
| 221 |
+
"Characterized by fluent speech with good comprehension but severely impaired repetition. Often involves phonemic paraphasias.",
|
| 222 |
+
"features": ["Fluent speech", "Good comprehension", "Poor repetition", "Phonemic errors"]},
|
| 223 |
+
"WERNICKE": {"name": "Wernicke's Aphasia (Fluent)", "description":
|
| 224 |
+
"Fluent but often meaningless speech with poor comprehension. Speech may contain neologisms and jargon.",
|
| 225 |
+
"features": ["Fluent speech", "Poor comprehension", "Jargon speech", "Neologisms"]},
|
| 226 |
+
"ANOMIC": {"name": "Anomic Aphasia", "description":
|
| 227 |
+
"Primarily characterized by word-finding difficulties with otherwise relatively preserved language abilities.",
|
| 228 |
+
"features": ["Word-finding difficulties", "Good comprehension", "Fluent speech", "Circumlocution"]},
|
| 229 |
+
"GLOBAL": {"name": "Global Aphasia", "description":
|
| 230 |
+
"Severe impairment in all language modalities - comprehension, production, repetition, and naming.",
|
| 231 |
+
"features": ["Severe comprehension deficit", "Non-fluent speech", "Poor repetition", "Severe naming difficulties"]},
|
| 232 |
+
"ISOLATION": {"name": "Isolation Syndrome", "description":
|
| 233 |
+
"Rare condition with preserved repetition but severely impaired comprehension and spontaneous speech.",
|
| 234 |
+
"features": ["Good repetition", "Poor comprehension", "Limited spontaneous speech", "Echolalia"]},
|
| 235 |
+
"TRANSSENSORY": {"name": "Trans-cortical Sensory Aphasia", "description":
|
| 236 |
+
"Fluent speech with good repetition but impaired comprehension, similar to Wernicke's but with preserved repetition.",
|
| 237 |
+
"features": ["Fluent speech", "Good repetition", "Poor comprehension", "Semantic errors"]}
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
self.load_configuration()
|
| 241 |
+
self.load_model()
|
| 242 |
+
print(f"推理系統初始化完成,使用設備: {self.device}")
|
| 243 |
+
|
| 244 |
+
def load_configuration(self):
|
| 245 |
+
cfg_path = os.path.join(self.model_dir, "config.json")
|
| 246 |
+
if os.path.exists(cfg_path):
|
| 247 |
+
with open(cfg_path, "r", encoding="utf-8") as f:
|
| 248 |
+
cfg = json.load(f)
|
| 249 |
+
self.aphasia_types_mapping = cfg.get("aphasia_types_mapping", {
|
| 250 |
+
"BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
|
| 251 |
+
"CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
|
| 252 |
+
"GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
|
| 253 |
+
})
|
| 254 |
+
self.num_labels = cfg.get("num_labels", 9)
|
| 255 |
+
self.model_name = cfg.get("model_name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
|
| 256 |
+
else:
|
| 257 |
+
self.aphasia_types_mapping = {
|
| 258 |
+
"BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
|
| 259 |
+
"CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
|
| 260 |
+
"GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
|
| 261 |
+
}
|
| 262 |
+
self.num_labels = 9
|
| 263 |
+
self.model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
|
| 264 |
+
self.id_to_aphasia_type = {v: k for k, v in self.aphasia_types_mapping.items()}
|
| 265 |
+
|
| 266 |
+
def load_model(self):
|
| 267 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, use_fast=True)
|
| 268 |
+
# pad token fix
|
| 269 |
+
if self.tokenizer.pad_token is None:
|
| 270 |
+
if self.tokenizer.eos_token is not None:
|
| 271 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 272 |
+
elif self.tokenizer.unk_token is not None:
|
| 273 |
+
self.tokenizer.pad_token = self.tokenizer.unk_token
|
| 274 |
+
else:
|
| 275 |
+
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
| 276 |
+
# optional added tokens
|
| 277 |
+
add_path = os.path.join(self.model_dir, "added_tokens.json")
|
| 278 |
+
if os.path.exists(add_path):
|
| 279 |
+
with open(add_path, "r", encoding="utf-8") as f:
|
| 280 |
+
data = json.load(f)
|
| 281 |
+
tokens = list(data.keys()) if isinstance(data, dict) else data
|
| 282 |
+
if tokens:
|
| 283 |
+
self.tokenizer.add_tokens(tokens)
|
| 284 |
+
|
| 285 |
+
self.config = ModelConfig()
|
| 286 |
+
self.config.model_name = self.model_name
|
| 287 |
+
|
| 288 |
+
self.model = StableAphasiaClassifier(self.config, self.num_labels)
|
| 289 |
+
self.model.bert.resize_token_embeddings(len(self.tokenizer))
|
| 290 |
+
|
| 291 |
+
model_path = os.path.join(self.model_dir, "pytorch_model.bin")
|
| 292 |
+
if not os.path.exists(model_path):
|
| 293 |
+
raise FileNotFoundError(f"模型權重文件不存在: {model_path}")
|
| 294 |
+
state = torch.load(model_path, map_location=self.device)
|
| 295 |
+
self.model.load_state_dict(state) # (once)
|
| 296 |
+
|
| 297 |
+
self.model.to(self.device)
|
| 298 |
+
self.model.eval()
|
| 299 |
+
|
| 300 |
+
# ---------- helpers ----------
|
| 301 |
+
|
| 302 |
+
def _dur_to_float(self, d) -> float:
|
| 303 |
+
"""Robustly parse duration from various shapes:
|
| 304 |
+
- number
|
| 305 |
+
- ["word", ms]
|
| 306 |
+
- [start, end]
|
| 307 |
+
- {"dur": ms} (future-proof)
|
| 308 |
+
"""
|
| 309 |
+
if isinstance(d, (int, float)):
|
| 310 |
+
return float(d)
|
| 311 |
+
if isinstance(d, list):
|
| 312 |
+
if len(d) == 2:
|
| 313 |
+
# ["word", 300] or [start, end]
|
| 314 |
+
a, b = d[0], d[1]
|
| 315 |
+
# case 1: word + ms
|
| 316 |
+
if isinstance(a, str) and isinstance(b, (int, float)):
|
| 317 |
+
return float(b)
|
| 318 |
+
# case 2: start, end
|
| 319 |
+
if isinstance(a, (int, float)) and isinstance(b, (int, float)):
|
| 320 |
+
return float(b) - float(a)
|
| 321 |
+
if isinstance(d, dict):
|
| 322 |
+
for k in ("dur", "duration", "ms"):
|
| 323 |
+
if k in d and isinstance(d[k], (int, float)):
|
| 324 |
+
return float(d[k])
|
| 325 |
+
return 0.0
|
| 326 |
+
|
| 327 |
+
def _extract_prosodic_features(self, durations, tokens):
|
| 328 |
+
vals = []
|
| 329 |
+
for d in durations:
|
| 330 |
+
vals.append(self._dur_to_float(d))
|
| 331 |
+
vals = [v for v in vals if v > 0]
|
| 332 |
+
if not vals:
|
| 333 |
+
return [0.0] * self.config.prosody_dim
|
| 334 |
+
features = [
|
| 335 |
+
float(np.mean(vals)),
|
| 336 |
+
float(np.std(vals)),
|
| 337 |
+
float(np.median(vals)),
|
| 338 |
+
float(len([v for v in vals if v > (np.mean(vals) * 1.5)])),
|
| 339 |
+
]
|
| 340 |
+
while len(features) < self.config.prosody_dim:
|
| 341 |
+
features.append(0.0)
|
| 342 |
+
return features[:self.config.prosody_dim]
|
| 343 |
+
|
| 344 |
+
def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
    """Align word-level features to the tokenizer's subtoken positions.

    Produces three lists of exactly ``self.config.max_length`` entries
    each ([CLS] slot + body + [SEP] slot): POS ids, grammar-id triples,
    and float durations.

    NOTE(review): ``encoded`` is accepted but unused; alignment is
    re-derived by re-tokenizing each word — presumably this matches how
    ``encoded`` was produced, but confirm truncation behaviour agrees.
    """
    # map subtoken -> original token index
    subtoken_to_token = []
    for idx, tok in enumerate(tokens):
        subtoks = self.tokenizer.tokenize(tok)
        # max(1, ...): a word that tokenizes to nothing still occupies one slot
        subtoken_to_token.extend([idx] * max(1, len(subtoks)))

    aligned_pos = [0]  # [CLS]
    aligned_grammar = [[0, 0, 0]]  # [CLS]
    aligned_durations = [0.0]  # [CLS]

    # reserve last slot for [SEP]
    max_body = self.config.max_length - 2
    for st_idx in range(max_body):
        if st_idx < len(subtoken_to_token):
            orig = subtoken_to_token[st_idx]
            # out-of-range feature lookups fall back to neutral defaults
            aligned_pos.append(pos_ids[orig] if orig < len(pos_ids) else 0)
            aligned_grammar.append(grammar_ids[orig] if orig < len(grammar_ids) else [0, 0, 0])
            aligned_durations.append(self._dur_to_float(durations[orig]) if orig < len(durations) else 0.0)
        else:
            # padding beyond the real subtoken sequence
            aligned_pos.append(0)
            aligned_grammar.append([0, 0, 0])
            aligned_durations.append(0.0)

    aligned_pos.append(0)  # [SEP]
    aligned_grammar.append([0, 0, 0])  # [SEP]
    aligned_durations.append(0.0)  # [SEP]
    return aligned_pos, aligned_grammar, aligned_durations
|
| 372 |
+
|
| 373 |
+
def preprocess_sentence(self, sentence_data: dict) -> Optional[dict]:
    """Flatten one sentence record into model-ready tensors.

    Concatenates tokens/POS/grammar/duration features of all PAR
    utterances across dialogues (a "[DIALOGUE]" marker token separates
    consecutive dialogues), tokenizes the joined text, aligns the
    word-level features to subtokens, and attaches a prosody tensor.

    Returns:
        A dict of input tensors plus bookkeeping fields, or None when
        the record contains no tokens.
    """
    all_tokens, all_pos, all_grammar, all_durations = [], [], [], []
    for d_idx, dialogue in enumerate(sentence_data.get("dialogues", [])):
        if d_idx > 0:
            # dialogue separator with neutral feature values
            all_tokens.append("[DIALOGUE]")
            all_pos.append(0)
            all_grammar.append([0, 0, 0])
            all_durations.append(0.0)
        for par in dialogue.get("PAR", []):
            if "tokens" in par and par["tokens"]:
                toks = par["tokens"]
                # missing feature lists fall back to neutral defaults
                pos_ids = par.get("word_pos_ids", [0] * len(toks))
                gra_ids = par.get("word_grammar_ids", [[0, 0, 0]] * len(toks))
                durs = par.get("word_durations", [0.0] * len(toks))
                all_tokens.extend(toks)
                all_pos.extend(pos_ids)
                all_grammar.extend(gra_ids)
                all_durations.extend(durs)
    if not all_tokens:
        return None

    text = " ".join(all_tokens)
    enc = self.tokenizer(text, max_length=self.config.max_length, padding="max_length",
                         truncation=True, return_tensors="pt")
    aligned_pos, aligned_gra, aligned_dur = self._align_features(
        all_tokens, all_pos, all_grammar, all_durations, enc
    )
    prosody = self._extract_prosodic_features(all_durations, all_tokens)
    # replicate the sentence-level prosody vector once per input position
    prosody_tensor = torch.tensor(prosody).unsqueeze(0).repeat(self.config.max_length, 1)

    return {
        "input_ids": enc["input_ids"].squeeze(0),
        "attention_mask": enc["attention_mask"].squeeze(0),
        "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
        "word_grammar_ids": torch.tensor(aligned_gra, dtype=torch.long),
        "word_durations": torch.tensor(aligned_dur, dtype=torch.float),
        "prosody_features": prosody_tensor.float(),
        "sentence_id": sentence_data.get("sentence_id", "unknown"),
        "original_tokens": all_tokens,
        "text": text
    }
|
| 414 |
+
|
| 415 |
+
def predict_single(self, sentence_data: dict) -> dict:
    """Run the model on one sentence record and return a rich result dict.

    Returns:
        A dict with the predicted class, confidence, per-class
        probability distribution, severity and fluency outputs — or an
        ``{"error": ...}`` dict when preprocessing produced nothing.
    """
    proc = self.preprocess_sentence(sentence_data)
    if proc is None:
        return {"error": "無法處理輸入數據", "sentence_id": sentence_data.get("sentence_id", "unknown")}
    # add a leading batch dimension and move every tensor to the device
    inp = {
        "input_ids": proc["input_ids"].unsqueeze(0).to(self.device),
        "attention_mask": proc["attention_mask"].unsqueeze(0).to(self.device),
        "word_pos_ids": proc["word_pos_ids"].unsqueeze(0).to(self.device),
        "word_grammar_ids": proc["word_grammar_ids"].unsqueeze(0).to(self.device),
        "word_durations": proc["word_durations"].unsqueeze(0).to(self.device),
        "prosody_features": proc["prosody_features"].unsqueeze(0).to(self.device),
    }
    with torch.no_grad():
        out = self.model(**inp)
        logits = out["logits"]
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]
        pred_id = int(np.argmax(probs))
        # assumes the severity head outputs 4 levels (indexed 0..3 below)
        # — TODO confirm against the model config
        sev = out["severity_pred"].cpu().numpy()[0]
        flu = float(out["fluency_pred"].cpu().numpy()[0][0])

    pred_type = self.id_to_aphasia_type[pred_id]
    conf = float(probs[pred_id])

    # full per-class probability distribution
    dist = {}
    for a_type, t_id in self.aphasia_types_mapping.items():
        dist[a_type] = {"probability": float(probs[t_id]), "percentage": f"{probs[t_id]*100:.2f}%"}

    # sorted by descending probability for readable output
    sorted_dist = dict(sorted(dist.items(), key=lambda x: x[1]["probability"], reverse=True))
    return {
        "sentence_id": proc["sentence_id"],
        "input_text": proc["text"],
        "original_tokens": proc["original_tokens"],
        "prediction": {
            "predicted_class": pred_type,
            "confidence": conf,
            "confidence_percentage": f"{conf*100:.2f}%"
        },
        "class_description": self.aphasia_descriptions.get(pred_type, {
            "name": pred_type, "description": "Description not available", "features": []
        }),
        "probability_distribution": sorted_dist,
        "additional_predictions": {
            "severity_distribution": {
                "level_0": float(sev[0]), "level_1": float(sev[1]),
                "level_2": float(sev[2]), "level_3": float(sev[3])
            },
            "predicted_severity_level": int(np.argmax(sev)),
            "fluency_score": flu,
            # fluency thresholds: >0.7 High, >0.4 Medium, else Low
            "fluency_rating": "High" if flu > 0.7 else ("Medium" if flu > 0.4 else "Low"),
        }
    }
|
| 466 |
+
|
| 467 |
+
def predict_batch(self, input_file: str, output_file: Optional[str] = None) -> Dict:
    """Predict every sentence in ``input_file`` and return a result dict.

    The returned dict contains a summary, the sentence count, and the
    per-sentence predictions; when ``output_file`` is given, the same
    dict is also written to disk as UTF-8 JSON.
    """
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    sentences = data.get("sentences", [])
    total = len(sentences)
    print(f"開始處理 {total} 個句子...")

    results = []
    for i, sentence in enumerate(sentences, start=1):
        print(f"處理第 {i}/{total} 個句子...")
        results.append(self.predict_single(sentence))

    final = {
        "summary": self._generate_summary(results),
        "total_sentences": len(results),
        "predictions": results,
    }
    if output_file:
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(final, f, ensure_ascii=False, indent=2)
        print(f"結果已保存到: {output_file}")
    return final
|
| 483 |
+
|
| 484 |
+
def _generate_summary(self, results: List[dict]) -> dict:
|
| 485 |
+
if not results:
|
| 486 |
+
return {}
|
| 487 |
+
class_counts = defaultdict(int)
|
| 488 |
+
confs, flus = [], []
|
| 489 |
+
sev_counts = defaultdict(int)
|
| 490 |
+
for r in results:
|
| 491 |
+
if "error" in r:
|
| 492 |
+
continue
|
| 493 |
+
c = r["prediction"]["predicted_class"]
|
| 494 |
+
class_counts[c] += 1
|
| 495 |
+
confs.append(r["prediction"]["confidence"])
|
| 496 |
+
flus.append(r["additional_predictions"]["fluency_score"])
|
| 497 |
+
sev_counts[r["additional_predictions"]["predicted_severity_level"]] += 1
|
| 498 |
+
avg_conf = float(np.mean(confs)) if confs else 0.0
|
| 499 |
+
avg_flu = float(np.mean(flus)) if flus else 0.0
|
| 500 |
+
return {
|
| 501 |
+
"classification_distribution": dict(class_counts),
|
| 502 |
+
"classification_percentages": {k: f"{v/len(results)*100:.1f}%" for k, v in class_counts.items()},
|
| 503 |
+
"average_confidence": f"{avg_conf:.3f}",
|
| 504 |
+
"average_fluency_score": f"{avg_flu:.3f}",
|
| 505 |
+
"severity_distribution": dict(sev_counts),
|
| 506 |
+
"confidence_statistics": {} if not confs else {
|
| 507 |
+
"mean": f"{np.mean(confs):.3f}",
|
| 508 |
+
"std": f"{np.std(confs):.3f}",
|
| 509 |
+
"min": f"{np.min(confs):.3f}",
|
| 510 |
+
"max": f"{np.max(confs):.3f}",
|
| 511 |
+
},
|
| 512 |
+
"most_common_prediction": max(class_counts.items(), key=lambda x: x[1])[0] if class_counts else "None",
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
def generate_detailed_report(self, results: List[dict], output_dir: str = "./inference_results"):
    """Write a per-sentence CSV and a summary-statistics JSON to ``output_dir``.

    Entries containing an "error" key are skipped.

    Returns:
        The pandas DataFrame of report rows, or None when there is
        nothing to report.
    """
    os.makedirs(output_dir, exist_ok=True)
    rows = []
    for r in results:
        if "error" in r:
            continue
        # one flat row per sentence for CSV export
        row = {
            "sentence_id": r["sentence_id"],
            "predicted_class": r["prediction"]["predicted_class"],
            "confidence": r["prediction"]["confidence"],
            "class_name": r["class_description"]["name"],
            "severity_level": r["additional_predictions"]["predicted_severity_level"],
            "fluency_score": r["additional_predictions"]["fluency_score"],
            "fluency_rating": r["additional_predictions"]["fluency_rating"],
            "input_text": r["input_text"],
        }
        # one extra column per aphasia-type probability
        for a_type, info in r["probability_distribution"].items():
            row[f"prob_{a_type}"] = info["probability"]
        rows.append(row)
    if not rows:
        return None
    df = pd.DataFrame(rows)
    df.to_csv(os.path.join(output_dir, "detailed_predictions.csv"), index=False, encoding="utf-8")
    # NOTE(review): .std() is NaN for a single row; json.dump then emits a
    # non-standard 'NaN' token — confirm downstream consumers tolerate it.
    summary_stats = {
        "total_predictions": int(len(rows)),
        "class_distribution": df["predicted_class"].value_counts().to_dict(),
        "average_confidence": float(df["confidence"].mean()),
        "confidence_std": float(df["confidence"].std()),
        "average_fluency": float(df["fluency_score"].mean()),
        "fluency_std": float(df["fluency_score"].std()),
        "severity_distribution": df["severity_level"].value_counts().to_dict(),
    }
    with open(os.path.join(output_dir, "summary_statistics.json"), "w", encoding="utf-8") as f:
        json.dump(summary_stats, f, ensure_ascii=False, indent=2)
    print(f"詳細報告已生成並保存到: {output_dir}")
    return df
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
# =========================
|
| 554 |
+
# Convenience: run directly or from pipeline
|
| 555 |
+
# =========================
|
| 556 |
+
|
| 557 |
+
def predict_from_chajson(model_dir: str, chajson_path: str, output_file: Optional[str] = None) -> Dict:
    """
    Convenience entry:
      - Accepts the JSON produced by cha_json.py
      - If it contains 'sentences', runs per-sentence like before
      - If it only contains 'text_all', creates a single pseudo-sentence

    Returns the result dict produced by ``predict_batch``.
    """
    with open(chajson_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    inf = AphasiaInferenceSystem(model_dir)

    # If there are sentences, use the full path
    if data.get("sentences"):
        return inf.predict_batch(chajson_path, output_file=output_file)

    # Else, fall back to a single synthetic sentence using text_all.
    # Hoisted: text_all.split() was previously evaluated four times.
    tokens = data.get("text_all", "").split()
    fake = {
        "sentences": [{
            "sentence_id": "S1",
            "dialogues": [{
                "INV": [],
                "PAR": [{
                    "tokens": tokens,
                    "word_pos_ids": [0] * len(tokens),
                    "word_grammar_ids": [[0, 0, 0]] * len(tokens),
                    "word_durations": [0.0] * len(tokens),
                }],
            }],
        }]
    }
    # predict_batch reads from disk, so write the synthetic input to a
    # sidecar file next to the original and clean it up afterwards.
    tmp_path = chajson_path + "._synthetic.json"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(fake, f, ensure_ascii=False, indent=2)
    out = inf.predict_batch(tmp_path, output_file=output_file)
    try:
        os.remove(tmp_path)
    except Exception:
        # best-effort cleanup; a leftover sidecar file is harmless
        pass
    return out
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
# ---------- CLI ----------
|
| 599 |
+
|
| 600 |
+
def main():
    """CLI entry point: parse arguments, run batch inference, and print a
    short summary of the results."""
    import argparse
    parser = argparse.ArgumentParser(description="失語症分類推理系統")
    parser.add_argument("--model_dir", type=str, required=False, default="./adaptive_aphasia_model",
                        help="訓練好的模型目錄路徑")
    parser.add_argument("--input_file", type=str, required=True,
                        help="輸入JSON文件(cha_json 的輸出)")
    parser.add_argument("--output_file", type=str, default="./aphasia_predictions.json",
                        help="輸出JSON文件路徑")
    parser.add_argument("--report_dir", type=str, default="./inference_results",
                        help="詳細報告輸出目錄")
    parser.add_argument("--generate_report", action="store_true",
                        help="是否生成詳細的CSV報告")
    args = parser.parse_args()

    try:
        print("正在初始化推理系統...")
        # local renamed from `sys` to avoid shadowing the stdlib module name
        system = AphasiaInferenceSystem(args.model_dir)

        print("開始執行批次預測...")
        results = system.predict_batch(args.input_file, args.output_file)

        if args.generate_report:
            print("生成詳細報告...")
            system.generate_detailed_report(results["predictions"], args.report_dir)

        print("\n=== 預測摘要 ===")
        summary = results["summary"]
        print(f"總句子數: {results['total_sentences']}")
        print(f"平均信心度: {summary.get('average_confidence', 'N/A')}")
        print(f"平均流利度: {summary.get('average_fluency_score', 'N/A')}")
        print(f"最常見預測: {summary.get('most_common_prediction', 'N/A')}")
        print("\n類別分佈:")
        percentages = summary.get("classification_percentages", {})
        for name, count in summary.get("classification_distribution", {}).items():
            pct = percentages.get(name, "0%")
            print(f" {name}: {count} ({pct})")
        print(f"\n結果已保存到: {args.output_file}")
    except Exception as e:
        print(f"錯誤: {str(e)}")
        import traceback
        traceback.print_exc()
|
| 640 |
+
|
| 641 |
+
# Standard script-entry guard: run the CLI only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|