Ellie5757575757 committed on
Commit
de21f73
·
verified ·
1 Parent(s): 5908912

Rename Json__Output.py to output.py

Browse files
Files changed (2) hide show
  1. Json__Output.py +0 -896
  2. output.py +642 -0
Json__Output.py DELETED
@@ -1,896 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- 失語症分類推理系統
4
- 用於載入訓練好的模型並對新的語音數據進行分類預測
5
- """
6
-
7
- import json
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- import numpy as np
12
- import os
13
- import math
14
- from typing import Dict, List, Optional, Tuple
15
- from dataclasses import dataclass
16
- import pandas as pd
17
- from transformers import AutoTokenizer, AutoModel
18
- from collections import defaultdict
19
-
20
- # Re-declare the model architecture (kept consistent with the training code)
21
@dataclass
class ModelConfig:
    """Hyper-parameters describing the aphasia-classifier architecture.

    Must match the configuration used at training time, otherwise the saved
    state dict will not load.
    """

    model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    max_length: int = 512            # tokenizer max sequence length
    hidden_size: int = 768           # BERT hidden size
    pos_vocab_size: int = 150        # size of the POS-tag embedding table
    pos_emb_dim: int = 64            # POS embedding dimension
    grammar_dim: int = 3             # raw grammar feature vector length per word
    grammar_hidden_dim: int = 64
    duration_hidden_dim: int = 128
    prosody_dim: int = 32            # length of the sentence-level prosody vector
    num_attention_heads: int = 8
    attention_dropout: float = 0.3
    # FIX: annotated Optional — the field legitimately holds None until
    # __post_init__ substitutes the default head sizes.
    classifier_hidden_dims: Optional[List[int]] = None
    dropout_rate: float = 0.3

    def __post_init__(self):
        # Default classifier head: two hidden layers of 512 and 256 units.
        if self.classifier_hidden_dims is None:
            self.classifier_hidden_dims = [512, 256]
40
-
41
class StablePositionalEncoding(nn.Module):
    """Fixed sinusoidal positional table blended with a small learnable offset.

    The output is ``x + 0.1 * (sinusoidal + learnable)``, i.e. the positional
    signal is deliberately down-weighted relative to the token features.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        # Classic transformer sinusoid table: sin on even dims, cos on odd.
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        self.register_buffer('pe', table.unsqueeze(0))

        # Small learnable correction on top of the fixed table.
        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        n = x.size(1)
        fixed = self.pe[:, :n, :].to(x.device)
        learned = self.learnable_pe[:n, :].unsqueeze(0).expand(x.size(0), -1, -1)
        return x + 0.1 * (fixed + learned)
62
-
63
class StableMultiHeadAttention(nn.Module):
    """Standard multi-head self-attention with a residual + LayerNorm output.

    Submodule attribute names (query/key/value/output_proj/layer_norm) are part
    of the saved state-dict layout and must not be renamed.
    """

    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        assert feature_dim % num_heads == 0, "feature_dim must divide evenly into heads"

        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads

        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, x, mask=None):
        b, n, _ = x.size()

        def split_heads(t):
            # (b, n, d) -> (b, heads, n, head_dim)
            return t.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        # Scaled dot-product scores.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Broadcast a (b, n) padding mask over heads and query positions.
            if mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(1)
            scores.masked_fill_(mask == 0, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))

        merged = torch.matmul(weights, v)
        merged = merged.transpose(1, 2).contiguous().view(b, n, self.feature_dim)

        # Residual connection followed by normalization.
        return self.layer_norm(self.output_proj(merged) + x)
101
-
102
class StableLinguisticFeatureExtractor(nn.Module):
    """Fuses word-level POS, grammar, duration and prosody cues into one pooled vector.

    Output dimension is half of the concatenated per-branch feature size.
    Submodule attribute names are part of the saved state-dict layout.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # POS branch: embedding lookup refined by self-attention.
        self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)

        # Grammar branch (3-dim raw features per word).
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.grammar_hidden_dim),
            nn.Dropout(config.dropout_rate * 0.3),
        )

        # Duration branch (scalar per word).
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.duration_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.duration_hidden_dim),
        )

        # Prosody branch (fixed-size vector broadcast over positions).
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.prosody_dim, config.prosody_dim),
            nn.ReLU(),
            nn.LayerNorm(config.prosody_dim),
        )

        fused_in = (config.pos_emb_dim + config.grammar_hidden_dim +
                    config.duration_hidden_dim + config.prosody_dim)
        self.feature_fusion = nn.Sequential(
            nn.Linear(fused_in, fused_in // 2),
            nn.Tanh(),
            nn.LayerNorm(fused_in // 2),
            nn.Dropout(config.dropout_rate),
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        # Clamp POS ids into the embedding table's valid range before lookup.
        safe_pos = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
        pos_vec = self.pos_attention(self.pos_embedding(safe_pos), attention_mask)

        grammar_vec = self.grammar_projection(grammar_ids.float())
        duration_vec = self.duration_projection(durations.unsqueeze(-1).float())
        prosody_vec = self.prosody_projection(prosody_features.float())

        fused = self.feature_fusion(
            torch.cat([pos_vec, grammar_vec, duration_vec, prosody_vec], dim=-1)
        )

        # Masked mean pooling over the valid (unpadded) positions.
        weights = attention_mask.unsqueeze(-1).float()
        return torch.sum(fused * weights, dim=1) / torch.sum(weights, dim=1)
159
-
160
class StableAphasiaClassifier(nn.Module):
    """BERT-based aphasia-type classifier with an auxiliary linguistic branch.

    Combines an attention-pooled BERT sentence representation with pooled
    word-level linguistic features, then predicts:
      - ``logits``: aphasia-type class scores (``num_labels``-way),
      - ``severity_pred``: a 4-way softmax severity distribution,
      - ``fluency_pred``: a sigmoid fluency score in [0, 1].
    """

    def __init__(self, config: ModelConfig, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        # Pretrained biomedical BERT backbone (downloads if not cached).
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.bert_config = self.bert.config

        # Extra (sinusoidal + learnable) positional signal on top of BERT output.
        self.positional_encoder = StablePositionalEncoding(
            d_model=self.bert_config.hidden_size,
            max_len=config.max_length
        )

        self.linguistic_extractor = StableLinguisticFeatureExtractor(config)

        bert_dim = self.bert_config.hidden_size
        # Must equal the extractor's output size (half the concatenated branch dims).
        linguistic_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                          config.duration_hidden_dim + config.prosody_dim) // 2

        # Projects [BERT ‖ linguistic] back down to bert_dim.
        self.feature_fusion = nn.Sequential(
            nn.Linear(bert_dim + linguistic_dim, bert_dim),
            nn.LayerNorm(bert_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate)
        )

        self.classifier = self._build_classifier(bert_dim, num_labels)

        # Auxiliary heads share the fused representation.
        self.severity_head = nn.Sequential(
            nn.Linear(bert_dim, 4),
            nn.Softmax(dim=-1)
        )

        self.fluency_head = nn.Sequential(
            nn.Linear(bert_dim, 1),
            nn.Sigmoid()
        )

    def _build_classifier(self, input_dim: int, num_labels: int):
        """Build the MLP head: (Linear→LayerNorm→Tanh→Dropout)* + final Linear."""
        layers = []
        current_dim = input_dim

        for hidden_dim in self.config.classifier_hidden_dims:
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Tanh(),
                nn.Dropout(self.config.dropout_rate)
            ])
            current_dim = hidden_dim

        layers.append(nn.Linear(current_dim, num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        # NOTE(review): ``labels`` is accepted but unused and "loss" is always
        # None — this appears to be the inference-only variant of the model.

        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs.last_hidden_state

        position_enhanced = self.positional_encoder(sequence_output)
        pooled_output = self._attention_pooling(position_enhanced, attention_mask)

        if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
            # Default to a zero prosody tensor when none is supplied.
            if prosody_features is None:
                batch_size, seq_len = input_ids.size()
                prosody_features = torch.zeros(
                    batch_size, seq_len, self.config.prosody_dim,
                    device=input_ids.device
                )

            linguistic_features = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations,
                prosody_features, attention_mask
            )
        else:
            # No word-level inputs: substitute zeros of the extractor's output size.
            linguistic_features = torch.zeros(
                input_ids.size(0),
                (self.config.pos_emb_dim + self.config.grammar_hidden_dim +
                 self.config.duration_hidden_dim + self.config.prosody_dim) // 2,
                device=input_ids.device
            )

        combined_features = torch.cat([pooled_output, linguistic_features], dim=1)
        fused_features = self.feature_fusion(combined_features)

        logits = self.classifier(fused_features)
        severity_pred = self.severity_head(fused_features)
        fluency_pred = self.fluency_head(fused_features)

        return {
            "logits": logits,
            "severity_pred": severity_pred,
            "fluency_pred": fluency_pred,
            "loss": None
        }

    def _attention_pooling(self, sequence_output, attention_mask):
        """Pool (batch, seq, dim) -> (batch, dim) with content-derived weights."""
        # Weights come from a softmax over each position's feature sum.
        attention_weights = torch.softmax(
            torch.sum(sequence_output, dim=-1, keepdim=True), dim=1
        )
        # Zero out padding, then renormalize so weights sum to 1 per example.
        attention_weights = attention_weights * attention_mask.unsqueeze(-1).float()
        attention_weights = attention_weights / (torch.sum(attention_weights, dim=1, keepdim=True) + 1e-9)
        pooled = torch.sum(sequence_output * attention_weights, dim=1)
        return pooled
267
-
268
-
269
class AphasiaInferenceSystem:
    """Aphasia-type classification inference system.

    Loads a trained ``StableAphasiaClassifier`` checkpoint (tokenizer, config
    and weights) from a directory and exposes single-sentence prediction,
    batch prediction over a JSON file, and CSV/JSON report generation.
    """

    def __init__(self, model_dir: str):
        """
        Initialize the inference system.

        Args:
            model_dir: Path to the trained-model directory containing
                ``config.json``, the tokenizer files and ``pytorch_model.bin``.
        """
        # BUGFIX: the argument was previously ignored in favour of a
        # hard-coded path; honor the caller-supplied directory.
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Human-readable descriptions attached to each predicted class.
        self.aphasia_descriptions = {
            "BROCA": {
                "name": "Broca's Aphasia (Non-fluent)",
                "description": "Characterized by limited speech output, difficulty with grammar and sentence formation, but relatively preserved comprehension. Speech is typically effortful and halting.",
                "features": ["Non-fluent speech", "Preserved comprehension", "Grammar difficulties", "Word-finding problems"]
            },
            "TRANSMOTOR": {
                "name": "Trans-cortical Motor Aphasia",
                "description": "Similar to Broca's aphasia but with preserved repetition abilities. Speech is non-fluent with good comprehension.",
                "features": ["Non-fluent speech", "Good repetition", "Preserved comprehension", "Grammar difficulties"]
            },
            "NOTAPHASICBYWAB": {
                "name": "Not Aphasic by WAB",
                "description": "Individuals who do not meet the criteria for aphasia according to the Western Aphasia Battery assessment.",
                "features": ["Normal language function", "No significant language impairment", "Good comprehension", "Fluent speech"]
            },
            "CONDUCTION": {
                "name": "Conduction Aphasia",
                "description": "Characterized by fluent speech with good comprehension but severely impaired repetition. Often involves phonemic paraphasias.",
                "features": ["Fluent speech", "Good comprehension", "Poor repetition", "Phonemic errors"]
            },
            "WERNICKE": {
                "name": "Wernicke's Aphasia (Fluent)",
                "description": "Fluent but often meaningless speech with poor comprehension. Speech may contain neologisms and jargon.",
                "features": ["Fluent speech", "Poor comprehension", "Jargon speech", "Neologisms"]
            },
            "ANOMIC": {
                "name": "Anomic Aphasia",
                "description": "Primarily characterized by word-finding difficulties with otherwise relatively preserved language abilities.",
                "features": ["Word-finding difficulties", "Good comprehension", "Fluent speech", "Circumlocution"]
            },
            "GLOBAL": {
                "name": "Global Aphasia",
                "description": "Severe impairment in all language modalities - comprehension, production, repetition, and naming.",
                "features": ["Severe comprehension deficit", "Non-fluent speech", "Poor repetition", "Severe naming difficulties"]
            },
            "ISOLATION": {
                "name": "Isolation Syndrome",
                "description": "Rare condition with preserved repetition but severely impaired comprehension and spontaneous speech.",
                "features": ["Good repetition", "Poor comprehension", "Limited spontaneous speech", "Echolalia"]
            },
            "TRANSSENSORY": {
                "name": "Trans-cortical Sensory Aphasia",
                "description": "Fluent speech with good repetition but impaired comprehension, similar to Wernicke's but with preserved repetition.",
                "features": ["Fluent speech", "Good repetition", "Poor comprehension", "Semantic errors"]
            }
        }

        # Load configuration/label mappings, then the model itself.
        self.load_configuration()
        self.load_model()

        print(f"推理系統初始化完成,使用設備: {self.device}")

    def load_configuration(self):
        """Load label mapping and model settings from ``config.json``, with defaults."""
        default_mapping = {
            "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
            "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
            "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
        }
        default_model = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

        config_path = os.path.join(self.model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config_data = json.load(f)
            self.aphasia_types_mapping = config_data.get("aphasia_types_mapping", default_mapping)
            self.num_labels = config_data.get("num_labels", 9)
            self.model_name = config_data.get("model_name", default_model)
        else:
            # No config file: fall back to the training defaults.
            self.aphasia_types_mapping = default_mapping
            self.num_labels = 9
            self.model_name = default_model

        # Reverse mapping: class id -> aphasia type name.
        self.id_to_aphasia_type = {v: k for k, v in self.aphasia_types_mapping.items()}

    def load_model(self):
        """Load the tokenizer, rebuild the architecture and restore the weights.

        Raises:
            FileNotFoundError: if ``pytorch_model.bin`` is missing.
        """
        # Tokenizer, including any tokens added during training.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        added_tokens_path = os.path.join(self.model_dir, "added_tokens.json")
        if os.path.exists(added_tokens_path):
            with open(added_tokens_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            # A dict maps token -> id, so use its keys; a plain list is used as-is.
            tokens = list(data.keys()) if isinstance(data, dict) else data
            num_added = self.tokenizer.add_tokens(tokens)
            print(f"新增到 tokenizer 的 token 數量: {num_added}")

        # Rebuild the architecture used at training time.
        self.config = ModelConfig()
        self.config.model_name = self.model_name

        self.model = StableAphasiaClassifier(self.config, self.num_labels)
        # Resize embeddings BEFORE loading so checkpoint tensor shapes match.
        self.model.bert.resize_token_embeddings(len(self.tokenizer))

        model_path = os.path.join(self.model_dir, "pytorch_model.bin")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型權重文件不存在: {model_path}")

        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location=self.device)
        # BUGFIX: the weights were previously loaded twice and the embeddings
        # resized a second time afterwards; once each, in this order, suffices.
        self.model.load_state_dict(state_dict)
        print("模型權重載入成功")

        # Move to the target device and freeze into evaluation mode.
        self.model.to(self.device)
        self.model.eval()

    def preprocess_sentence(self, sentence_data: dict) -> Optional[dict]:
        """Convert one sentence record into model-ready tensors.

        Only participant ("PAR") utterances are used; consecutive dialogues
        are separated by a ``[DIALOGUE]`` sentinel token.

        Returns:
            A dict of tensors plus bookkeeping fields, or ``None`` when the
            record contains no participant tokens.
        """
        all_tokens, all_pos, all_grammar, all_durations = [], [], [], []

        for dialogue_idx, dialogue in enumerate(sentence_data.get("dialogues", [])):
            if dialogue_idx > 0:
                # Sentinel marking a dialogue boundary.
                all_tokens.append("[DIALOGUE]")
                all_pos.append(0)
                all_grammar.append([0, 0, 0])
                all_durations.append(0.0)

            for par in dialogue.get("PAR", []):
                if "tokens" in par and par["tokens"]:
                    tokens = par["tokens"]
                    # Missing word-level features default to neutral values.
                    pos_ids = par.get("word_pos_ids", [0] * len(tokens))
                    grammar_ids = par.get("word_grammar_ids", [[0, 0, 0]] * len(tokens))
                    durations = par.get("word_durations", [0.0] * len(tokens))

                    all_tokens.extend(tokens)
                    all_pos.extend(pos_ids)
                    all_grammar.extend(grammar_ids)
                    all_durations.extend(durations)

        if not all_tokens:
            return None

        text = " ".join(all_tokens)
        encoded = self.tokenizer(
            text,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Align word-level features with BERT subword positions.
        aligned_pos, aligned_grammar, aligned_durations = self._align_features(
            all_tokens, all_pos, all_grammar, all_durations, encoded
        )

        # Sentence-level prosody vector, repeated for every sequence position.
        prosody_features = self._extract_prosodic_features(all_durations, all_tokens)
        prosody_tensor = torch.tensor(prosody_features).unsqueeze(0).repeat(
            self.config.max_length, 1
        )

        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
            "word_grammar_ids": torch.tensor(aligned_grammar, dtype=torch.long),
            "word_durations": torch.tensor(aligned_durations, dtype=torch.float),
            "prosody_features": prosody_tensor.float(),
            "sentence_id": sentence_data.get("sentence_id", "unknown"),
            "original_tokens": all_tokens,
            "text": text
        }

    def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
        """Expand word-level features to BERT subword positions.

        Positions 0 and ``max_length - 1`` get neutral values for the [CLS]
        and [SEP] slots; every subword inherits its source word's features.
        """
        # Map each subword index back to the index of the word it came from.
        subtoken_to_token = []
        for token_idx, token in enumerate(tokens):
            subtokens = self.tokenizer.tokenize(token)
            subtoken_to_token.extend([token_idx] * len(subtokens))

        aligned_pos = [0]               # [CLS]
        aligned_grammar = [[0, 0, 0]]   # [CLS]
        aligned_durations = [0.0]       # [CLS]

        for subtoken_idx in range(1, self.config.max_length - 1):
            if subtoken_idx - 1 < len(subtoken_to_token):
                original_idx = subtoken_to_token[subtoken_idx - 1]
                aligned_pos.append(pos_ids[original_idx] if original_idx < len(pos_ids) else 0)
                aligned_grammar.append(grammar_ids[original_idx] if original_idx < len(grammar_ids) else [0, 0, 0])

                # Durations may be [start, end] pairs or plain scalars.
                raw_duration = durations[original_idx] if original_idx < len(durations) else 0.0
                if isinstance(raw_duration, list) and len(raw_duration) >= 2:
                    try:
                        duration_val = float(raw_duration[1]) - float(raw_duration[0])
                    except (ValueError, TypeError):
                        duration_val = 0.0
                elif isinstance(raw_duration, (int, float)):
                    duration_val = float(raw_duration)
                else:
                    duration_val = 0.0

                aligned_durations.append(duration_val)
            else:
                # Padding region beyond the real subwords.
                aligned_pos.append(0)
                aligned_grammar.append([0, 0, 0])
                aligned_durations.append(0.0)

        aligned_pos.append(0)               # [SEP]
        aligned_grammar.append([0, 0, 0])   # [SEP]
        aligned_durations.append(0.0)       # [SEP]

        return aligned_pos, aligned_grammar, aligned_durations

    def _extract_prosodic_features(self, durations, tokens):
        """Build a fixed-length prosody vector from word durations.

        Features: mean, std, median and the count of "long" words (duration
        above 1.5x the mean), zero-padded to ``config.prosody_dim``.
        """
        if not durations:
            return [0.0] * self.config.prosody_dim

        # Normalize entries: [start, end] pairs become end - start; bad values skipped.
        processed_durations = []
        for d in durations:
            if isinstance(d, list) and len(d) >= 2:
                try:
                    processed_durations.append(float(d[1]) - float(d[0]))
                except (ValueError, TypeError):
                    continue
            elif isinstance(d, (int, float)):
                processed_durations.append(float(d))

        if not processed_durations:
            return [0.0] * self.config.prosody_dim

        features = [
            np.mean(processed_durations),
            np.std(processed_durations),
            np.median(processed_durations),
            len([d for d in processed_durations if d > np.mean(processed_durations) * 1.5])
        ]

        # Zero-pad (or truncate) to the configured dimension.
        while len(features) < self.config.prosody_dim:
            features.append(0.0)

        return features[:self.config.prosody_dim]

    def predict_single(self, sentence_data: dict) -> dict:
        """Run inference on one sentence record.

        Returns a result dict with the predicted class, per-class probability
        distribution, severity distribution and fluency score, or an
        ``{"error": ...}`` dict when the record cannot be processed.
        """
        processed_data = self.preprocess_sentence(sentence_data)
        if processed_data is None:
            return {
                "error": "無法處理輸入數據",
                "sentence_id": sentence_data.get("sentence_id", "unknown")
            }

        # Add the batch dimension and move everything to the model's device.
        input_data = {
            "input_ids": processed_data["input_ids"].unsqueeze(0).to(self.device),
            "attention_mask": processed_data["attention_mask"].unsqueeze(0).to(self.device),
            "word_pos_ids": processed_data["word_pos_ids"].unsqueeze(0).to(self.device),
            "word_grammar_ids": processed_data["word_grammar_ids"].unsqueeze(0).to(self.device),
            "word_durations": processed_data["word_durations"].unsqueeze(0).to(self.device),
            "prosody_features": processed_data["prosody_features"].unsqueeze(0).to(self.device)
        }

        with torch.no_grad():
            outputs = self.model(**input_data)

        logits = outputs["logits"]
        probabilities = F.softmax(logits, dim=1).cpu().numpy()[0]
        # Cast to a plain int so it indexes the id mapping cleanly.
        predicted_class_id = int(np.argmax(probabilities))

        severity_pred = outputs["severity_pred"].cpu().numpy()[0]
        fluency_pred = outputs["fluency_pred"].cpu().numpy()[0][0]

        predicted_type = self.id_to_aphasia_type[predicted_class_id]
        confidence = float(probabilities[predicted_class_id])

        # Per-class probability distribution, sorted most-likely first.
        probability_distribution = {}
        for aphasia_type, type_id in self.aphasia_types_mapping.items():
            probability_distribution[aphasia_type] = {
                "probability": float(probabilities[type_id]),
                "percentage": f"{probabilities[type_id]*100:.2f}%"
            }
        sorted_probabilities = sorted(
            probability_distribution.items(),
            key=lambda x: x[1]["probability"],
            reverse=True
        )

        return {
            "sentence_id": processed_data["sentence_id"],
            "input_text": processed_data["text"],
            "original_tokens": processed_data["original_tokens"],
            "prediction": {
                "predicted_class": predicted_type,
                "confidence": confidence,
                "confidence_percentage": f"{confidence*100:.2f}%"
            },
            "class_description": self.aphasia_descriptions.get(predicted_type, {
                "name": predicted_type,
                "description": "Description not available",
                "features": []
            }),
            "probability_distribution": dict(sorted_probabilities),
            "additional_predictions": {
                "severity_distribution": {
                    "level_0": float(severity_pred[0]),
                    "level_1": float(severity_pred[1]),
                    "level_2": float(severity_pred[2]),
                    "level_3": float(severity_pred[3])
                },
                "predicted_severity_level": int(np.argmax(severity_pred)),
                "fluency_score": float(fluency_pred),
                "fluency_rating": "High" if fluency_pred > 0.7 else "Medium" if fluency_pred > 0.4 else "Low"
            }
        }

    def predict_batch(self, input_file: str, output_file: str = None) -> dict:
        """Predict every sentence in a JSON input file.

        FIX: the return annotation previously claimed ``List[dict]`` but the
        method has always returned a dict with summary/total/predictions keys.

        Args:
            input_file: JSON file with a top-level ``"sentences"`` list.
            output_file: Optional path; when given, results are written as JSON.

        Returns:
            ``{"summary": ..., "total_sentences": ..., "predictions": [...]}``.
        """
        with open(input_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        sentences = data.get("sentences", [])
        results = []

        print(f"開始處理 {len(sentences)} 個句子...")

        for i, sentence in enumerate(sentences):
            print(f"處理第 {i+1}/{len(sentences)} 個句子...")
            results.append(self.predict_single(sentence))

        final_output = {
            "summary": self._generate_summary(results),
            "total_sentences": len(results),
            "predictions": results
        }

        if output_file:
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(final_output, f, ensure_ascii=False, indent=2)
            print(f"結果已保存到: {output_file}")

        return final_output

    def _generate_summary(self, results: List[dict]) -> dict:
        """Aggregate per-sentence predictions into summary statistics.

        Error entries are skipped for the statistics, but percentages are
        still computed over the full result count (including errors).
        """
        if not results:
            return {}

        class_counts = defaultdict(int)
        confidence_scores = []
        fluency_scores = []
        severity_levels = defaultdict(int)

        for result in results:
            if "error" not in result:
                predicted_class = result["prediction"]["predicted_class"]
                confidence = result["prediction"]["confidence"]
                fluency = result["additional_predictions"]["fluency_score"]
                severity = result["additional_predictions"]["predicted_severity_level"]

                class_counts[predicted_class] += 1
                confidence_scores.append(confidence)
                fluency_scores.append(fluency)
                severity_levels[severity] += 1

        avg_confidence = np.mean(confidence_scores) if confidence_scores else 0
        avg_fluency = np.mean(fluency_scores) if fluency_scores else 0

        return {
            "classification_distribution": dict(class_counts),
            "classification_percentages": {
                k: f"{v/len(results)*100:.1f}%"
                for k, v in class_counts.items()
            },
            "average_confidence": f"{avg_confidence:.3f}",
            "average_fluency_score": f"{avg_fluency:.3f}",
            "severity_distribution": dict(severity_levels),
            "confidence_statistics": {
                "mean": f"{np.mean(confidence_scores):.3f}",
                "std": f"{np.std(confidence_scores):.3f}",
                "min": f"{np.min(confidence_scores):.3f}",
                "max": f"{np.max(confidence_scores):.3f}"
            } if confidence_scores else {},
            "most_common_prediction": max(class_counts.items(), key=lambda x: x[1])[0] if class_counts else "None"
        }

    def generate_detailed_report(self, results: List[dict], output_dir: str = "./inference_results"):
        """Write a per-sentence CSV plus a JSON statistics summary.

        Returns:
            The report DataFrame, or ``None`` when there were no valid results.
        """
        os.makedirs(output_dir, exist_ok=True)

        report_data = []
        for result in results:
            if "error" not in result:
                row = {
                    "sentence_id": result["sentence_id"],
                    "predicted_class": result["prediction"]["predicted_class"],
                    "confidence": result["prediction"]["confidence"],
                    "class_name": result["class_description"]["name"],
                    "severity_level": result["additional_predictions"]["predicted_severity_level"],
                    "fluency_score": result["additional_predictions"]["fluency_score"],
                    "fluency_rating": result["additional_predictions"]["fluency_rating"],
                    "input_text": result["input_text"]
                }

                # One probability column per known class.
                for aphasia_type in self.aphasia_types_mapping.keys():
                    row[f"prob_{aphasia_type}"] = result["probability_distribution"][aphasia_type]["probability"]

                report_data.append(row)

        if report_data:
            df = pd.DataFrame(report_data)
            df.to_csv(os.path.join(output_dir, "detailed_predictions.csv"), index=False, encoding='utf-8')

            summary_stats = {
                "total_predictions": len(report_data),
                "class_distribution": df["predicted_class"].value_counts().to_dict(),
                "average_confidence": df["confidence"].mean(),
                "confidence_std": df["confidence"].std(),
                "average_fluency": df["fluency_score"].mean(),
                "fluency_std": df["fluency_score"].std(),
                "severity_distribution": df["severity_level"].value_counts().to_dict()
            }

            with open(os.path.join(output_dir, "summary_statistics.json"), "w", encoding="utf-8") as f:
                json.dump(summary_stats, f, ensure_ascii=False, indent=2)

            print(f"詳細報告已生成並保存到: {output_dir}")
            return df

        return None
750
-
751
-
752
def main():
    """Command-line entry point: run batch inference and print a summary."""
    import argparse

    arg_parser = argparse.ArgumentParser(description="失語症分類推理系統")
    arg_parser.add_argument("--model_dir", type=str,
                            default='/workspace/SH001/adaptive_aphasia_model',
                            help="訓練好的模型目錄路徑")
    arg_parser.add_argument("--input_file", type=str,
                            default='/workspace/SH001/website/sample.input.json',
                            help="輸入JSON文件路徑")
    arg_parser.add_argument("--output_file", type=str, default="./aphasia_predictions.json",
                            help="輸出JSON文件路徑")
    arg_parser.add_argument("--report_dir", type=str, default="./inference_results",
                            help="詳細報告輸出目錄")
    arg_parser.add_argument("--generate_report", action="store_true",
                            help="是否生成詳細的CSV報告")
    opts = arg_parser.parse_args()

    try:
        # Build the inference system from the trained-model directory.
        print("正在初始化推理系統...")
        system = AphasiaInferenceSystem(opts.model_dir)

        # Predict every sentence in the input file.
        print("開始執行批次預測...")
        batch_output = system.predict_batch(opts.input_file, opts.output_file)

        # Optional per-sentence CSV report.
        if opts.generate_report:
            print("生成詳細報告...")
            system.generate_detailed_report(batch_output["predictions"], opts.report_dir)

        # Console summary.
        stats = batch_output["summary"]
        print("\n=== 預測摘要 ===")
        print(f"總句子數: {batch_output['total_sentences']}")
        print(f"平均信心度: {stats.get('average_confidence', 'N/A')}")
        print(f"平均流利度: {stats.get('average_fluency_score', 'N/A')}")
        print(f"最常見預測: {stats.get('most_common_prediction', 'N/A')}")

        print("\n類別分佈:")
        for label, n_hits in stats.get("classification_distribution", {}).items():
            share = stats.get("classification_percentages", {}).get(label, "0%")
            print(f" {label}: {n_hits} ({share})")

        print(f"\n結果已保存到: {opts.output_file}")
    except Exception as exc:
        # CLI boundary: report the failure with a traceback instead of crashing silently.
        print(f"錯誤: {str(exc)}")
        import traceback
        traceback.print_exc()
803
-
804
-
805
- # 使用範例
806
def example_usage():
    """Usage example: write a sample input file and print usage instructions."""

    # 1. Basic usage banner.
    print("=== 失語症分類推理系統使用範例 ===\n")

    # Example input record in the format expected by predict_batch /
    # predict_single: a list of sentences, each holding dialogues with
    # INV (investigator) and PAR (participant) utterances.
    sample_input = {
        "sentences": [
            {
                "sentence_id": "S1",
                "aphasia_type": "BROCA",  # ignored at inference time
                "dialogues": [
                    {
                        "INV": [
                            {
                                "tokens": ["how", "are", "you", "feeling"],
                                "word_pos_ids": [9, 10, 5, 6],
                                "word_grammar_ids": [[1, 4, 11], [2, 4, 2], [3, 4, 1], [4, 0, 3]],
                                "word_durations": [["how", 300], ["are", 200], ["you", 150], ["feeling", 500]]
                            }
                        ],
                        "PAR": [
                            {
                                "tokens": ["I", "feel", "good"],
                                "word_pos_ids": [1, 6, 8],
                                "word_grammar_ids": [[1, 2, 1], [2, 3, 2], [3, 4, 8]],
                                "word_durations": [["I", 200], ["feel", 400], ["good", 600]]
                            }
                        ]
                    }
                ]
            }
        ]
    }

    # Save the sample input next to the script.
    with open("sample_input.json", "w", encoding="utf-8") as f:
        json.dump(sample_input, f, ensure_ascii=False, indent=2)

    print("範例輸入文件已創建: sample_input.json")

    # Show the (user-facing, intentionally untranslated) usage instructions.
    usage_instructions = """
    使用方法:

    1. 命令行使用:
    python aphasia_inference.py \\
        --model_dir ./adaptive_aphasia_model \\
        --input_file sample_input.json \\
        --output_file predictions.json \\
        --generate_report \\
        --report_dir ./results

    2. Python代碼使用:
    from aphasia_inference import AphasiaInferenceSystem

    # 初始化系統
    system = AphasiaInferenceSystem("./adaptive_aphasia_model")

    # 單個預測
    with open("sample_input.json", "r") as f:
        data = json.load(f)
    result = system.predict_single(data["sentences"][0])

    # 批次預測
    results = system.predict_batch("sample_input.json", "output.json")

    3. 輸出格式:
    - JSON格式包含詳細的預測結果和機率分佈
    - CSV格式包含表格化的預測數據
    - 統計摘要包含整體分析結果

    4. 支援的失語症類型:
    - BROCA: 布若卡失語症
    - WERNICKE: 韋尼克失語症
    - ANOMIC: 命名性失語症
    - CONDUCTION: 傳導性失語症
    - GLOBAL: 全面性失語症
    - 以及其他類型...
    """

    print(usage_instructions)
889
-
890
-
891
- if __name__ == "__main__":
892
- # 如果作為腳本執行,運行主程式
893
- main()
894
-
895
- # 如果想看使用範例,取消下面這行的註釋
896
- # example_usage()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
output.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Aphasia classification inference (cleaned).
4
+ - Respects model_dir argument
5
+ - Correctly parses durations like ["word", 300] and [start, end]
6
+ - Removes duplicate load_state_dict
7
+ - Adds predict_from_chajson(json_path, ...) helper
8
+ """
9
+
10
+ import json
11
+ import os
12
+ import math
13
+ from dataclasses import dataclass
14
+ from typing import Dict, List, Optional, Tuple
15
+ from collections import defaultdict
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import pandas as pd
22
+ from transformers import AutoTokenizer, AutoModel
23
+
24
+
25
+ # =========================
26
+ # Model definition (unchanged shape)
27
+ # =========================
28
+
29
+ @dataclass
30
+ class ModelConfig:
31
+ model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
32
+ max_length: int = 512
33
+ hidden_size: int = 768
34
+ pos_vocab_size: int = 150
35
+ pos_emb_dim: int = 64
36
+ grammar_dim: int = 3
37
+ grammar_hidden_dim: int = 64
38
+ duration_hidden_dim: int = 128
39
+ prosody_dim: int = 32
40
+ num_attention_heads: int = 8
41
+ attention_dropout: float = 0.3
42
+ classifier_hidden_dims: List[int] = None
43
+ dropout_rate: float = 0.3
44
+ def __post_init__(self):
45
+ if self.classifier_hidden_dims is None:
46
+ self.classifier_hidden_dims = [512, 256]
47
+
48
+ class StablePositionalEncoding(nn.Module):
49
+ def __init__(self, d_model: int, max_len: int = 5000):
50
+ super().__init__()
51
+ self.d_model = d_model
52
+ pe = torch.zeros(max_len, d_model)
53
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
54
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
55
+ pe[:, 0::2] = torch.sin(position * div_term)
56
+ pe[:, 1::2] = torch.cos(position * div_term)
57
+ self.register_buffer('pe', pe.unsqueeze(0))
58
+ self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)
59
+ def forward(self, x):
60
+ seq_len = x.size(1)
61
+ sinusoidal = self.pe[:, :seq_len, :].to(x.device)
62
+ learnable = self.learnable_pe[:seq_len, :].unsqueeze(0).expand(x.size(0), -1, -1)
63
+ return x + 0.1 * (sinusoidal + learnable)
64
+
65
+ class StableMultiHeadAttention(nn.Module):
66
+ def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
67
+ super().__init__()
68
+ self.num_heads = num_heads
69
+ self.feature_dim = feature_dim
70
+ self.head_dim = feature_dim // num_heads
71
+ assert feature_dim % num_heads == 0
72
+ self.query = nn.Linear(feature_dim, feature_dim)
73
+ self.key = nn.Linear(feature_dim, feature_dim)
74
+ self.value = nn.Linear(feature_dim, feature_dim)
75
+ self.dropout = nn.Dropout(dropout)
76
+ self.output_proj = nn.Linear(feature_dim, feature_dim)
77
+ self.layer_norm = nn.LayerNorm(feature_dim)
78
+ def forward(self, x, mask=None):
79
+ b, t, _ = x.size()
80
+ Q = self.query(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
81
+ K = self.key(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
82
+ V = self.value(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
83
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
84
+ if mask is not None:
85
+ if mask.dim() == 2:
86
+ mask = mask.unsqueeze(1).unsqueeze(1)
87
+ scores.masked_fill_(mask == 0, -1e9)
88
+ attn = F.softmax(scores, dim=-1)
89
+ attn = self.dropout(attn)
90
+ ctx = torch.matmul(attn, V)
91
+ ctx = ctx.transpose(1, 2).contiguous().view(b, t, self.feature_dim)
92
+ out = self.output_proj(ctx)
93
+ return self.layer_norm(out + x)
94
+
95
+ class StableLinguisticFeatureExtractor(nn.Module):
96
+ def __init__(self, config: ModelConfig):
97
+ super().__init__()
98
+ self.config = config
99
+ self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
100
+ self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)
101
+ self.grammar_projection = nn.Sequential(
102
+ nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
103
+ nn.Tanh(),
104
+ nn.LayerNorm(config.grammar_hidden_dim),
105
+ nn.Dropout(config.dropout_rate * 0.3)
106
+ )
107
+ self.duration_projection = nn.Sequential(
108
+ nn.Linear(1, config.duration_hidden_dim),
109
+ nn.Tanh(),
110
+ nn.LayerNorm(config.duration_hidden_dim)
111
+ )
112
+ self.prosody_projection = nn.Sequential(
113
+ nn.Linear(config.prosody_dim, config.prosody_dim),
114
+ nn.ReLU(),
115
+ nn.LayerNorm(config.prosody_dim)
116
+ )
117
+ total_feature_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
118
+ config.duration_hidden_dim + config.prosody_dim)
119
+ self.feature_fusion = nn.Sequential(
120
+ nn.Linear(total_feature_dim, total_feature_dim // 2),
121
+ nn.Tanh(),
122
+ nn.LayerNorm(total_feature_dim // 2),
123
+ nn.Dropout(config.dropout_rate)
124
+ )
125
+ def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
126
+ b, t = pos_ids.size()
127
+ pos_ids = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
128
+ pos_emb = self.pos_embedding(pos_ids)
129
+ pos_feat = self.pos_attention(pos_emb, attention_mask)
130
+ gra_feat = self.grammar_projection(grammar_ids.float())
131
+ dur_feat = self.duration_projection(durations.unsqueeze(-1).float())
132
+ pro_feat = self.prosody_projection(prosody_features.float())
133
+ combined = torch.cat([pos_feat, gra_feat, dur_feat, pro_feat], dim=-1)
134
+ fused = self.feature_fusion(combined)
135
+ mask_exp = attention_mask.unsqueeze(-1).float()
136
+ pooled = torch.sum(fused * mask_exp, dim=1) / torch.sum(mask_exp, dim=1)
137
+ return pooled
138
+
139
+ class StableAphasiaClassifier(nn.Module):
140
+ def __init__(self, config: ModelConfig, num_labels: int):
141
+ super().__init__()
142
+ self.config = config
143
+ self.num_labels = num_labels
144
+ self.bert = AutoModel.from_pretrained(config.model_name)
145
+ self.bert_config = self.bert.config
146
+ self.positional_encoder = StablePositionalEncoding(d_model=self.bert_config.hidden_size,
147
+ max_len=config.max_length)
148
+ self.linguistic_extractor = StableLinguisticFeatureExtractor(config)
149
+ bert_dim = self.bert_config.hidden_size
150
+ lingu_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
151
+ config.duration_hidden_dim + config.prosody_dim) // 2
152
+ self.feature_fusion = nn.Sequential(
153
+ nn.Linear(bert_dim + lingu_dim, bert_dim),
154
+ nn.LayerNorm(bert_dim),
155
+ nn.Tanh(),
156
+ nn.Dropout(config.dropout_rate)
157
+ )
158
+ self.classifier = self._build_classifier(bert_dim, num_labels)
159
+ self.severity_head = nn.Sequential(nn.Linear(bert_dim, 4), nn.Softmax(dim=-1))
160
+ self.fluency_head = nn.Sequential(nn.Linear(bert_dim, 1), nn.Sigmoid())
161
+ def _build_classifier(self, input_dim: int, num_labels: int):
162
+ layers, cur = [], input_dim
163
+ for h in self.config.classifier_hidden_dims:
164
+ layers += [nn.Linear(cur, h), nn.LayerNorm(h), nn.Tanh(), nn.Dropout(self.config.dropout_rate)]
165
+ cur = h
166
+ layers.append(nn.Linear(cur, num_labels))
167
+ return nn.Sequential(*layers)
168
+ def _attention_pooling(self, seq_out, attn_mask):
169
+ attn_w = torch.softmax(torch.sum(seq_out, dim=-1, keepdim=True), dim=1)
170
+ attn_w = attn_w * attn_mask.unsqueeze(-1).float()
171
+ attn_w = attn_w / (torch.sum(attn_w, dim=1, keepdim=True) + 1e-9)
172
+ return torch.sum(seq_out * attn_w, dim=1)
173
+ def forward(self, input_ids, attention_mask, labels=None,
174
+ word_pos_ids=None, word_grammar_ids=None, word_durations=None,
175
+ prosody_features=None, **kwargs):
176
+ bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
177
+ seq_out = bert_out.last_hidden_state
178
+ pos_enh = self.positional_encoder(seq_out)
179
+ pooled = self._attention_pooling(pos_enh, attention_mask)
180
+ if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
181
+ if prosody_features is None:
182
+ b, t = input_ids.size()
183
+ prosody_features = torch.zeros(b, t, self.config.prosody_dim, device=input_ids.device)
184
+ ling = self.linguistic_extractor(word_pos_ids, word_grammar_ids, word_durations,
185
+ prosody_features, attention_mask)
186
+ else:
187
+ ling = torch.zeros(input_ids.size(0),
188
+ (self.config.pos_emb_dim + self.config.grammar_hidden_dim +
189
+ self.config.duration_hidden_dim + self.config.prosody_dim) // 2,
190
+ device=input_ids.device)
191
+ fused = self.feature_fusion(torch.cat([pooled, ling], dim=1))
192
+ logits = self.classifier(fused)
193
+ severity_pred = self.severity_head(fused)
194
+ fluency_pred = self.fluency_head(fused)
195
+ return {"logits": logits, "severity_pred": severity_pred, "fluency_pred": fluency_pred, "loss": None}
196
+
197
+
198
+ # =========================
199
+ # Inference system (fixed wiring)
200
+ # =========================
201
+
202
+ class AphasiaInferenceSystem:
203
+ """失語症分類推理系統"""
204
+
205
+ def __init__(self, model_dir: str):
206
+ self.model_dir = model_dir # <— honor the argument
207
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
208
+
209
+ # Descriptions (unchanged)
210
+ self.aphasia_descriptions = {
211
+ "BROCA": {"name": "Broca's Aphasia (Non-fluent)", "description":
212
+ "Characterized by limited speech output, difficulty with grammar and sentence formation, but relatively preserved comprehension. Speech is typically effortful and halting.",
213
+ "features": ["Non-fluent speech", "Preserved comprehension", "Grammar difficulties", "Word-finding problems"]},
214
+ "TRANSMOTOR": {"name": "Trans-cortical Motor Aphasia", "description":
215
+ "Similar to Broca's aphasia but with preserved repetition abilities. Speech is non-fluent with good comprehension.",
216
+ "features": ["Non-fluent speech", "Good repetition", "Preserved comprehension", "Grammar difficulties"]},
217
+ "NOTAPHASICBYWAB": {"name": "Not Aphasic by WAB", "description":
218
+ "Individuals who do not meet the criteria for aphasia according to the Western Aphasia Battery assessment.",
219
+ "features": ["Normal language function", "No significant language impairment", "Good comprehension", "Fluent speech"]},
220
+ "CONDUCTION": {"name": "Conduction Aphasia", "description":
221
+ "Characterized by fluent speech with good comprehension but severely impaired repetition. Often involves phonemic paraphasias.",
222
+ "features": ["Fluent speech", "Good comprehension", "Poor repetition", "Phonemic errors"]},
223
+ "WERNICKE": {"name": "Wernicke's Aphasia (Fluent)", "description":
224
+ "Fluent but often meaningless speech with poor comprehension. Speech may contain neologisms and jargon.",
225
+ "features": ["Fluent speech", "Poor comprehension", "Jargon speech", "Neologisms"]},
226
+ "ANOMIC": {"name": "Anomic Aphasia", "description":
227
+ "Primarily characterized by word-finding difficulties with otherwise relatively preserved language abilities.",
228
+ "features": ["Word-finding difficulties", "Good comprehension", "Fluent speech", "Circumlocution"]},
229
+ "GLOBAL": {"name": "Global Aphasia", "description":
230
+ "Severe impairment in all language modalities - comprehension, production, repetition, and naming.",
231
+ "features": ["Severe comprehension deficit", "Non-fluent speech", "Poor repetition", "Severe naming difficulties"]},
232
+ "ISOLATION": {"name": "Isolation Syndrome", "description":
233
+ "Rare condition with preserved repetition but severely impaired comprehension and spontaneous speech.",
234
+ "features": ["Good repetition", "Poor comprehension", "Limited spontaneous speech", "Echolalia"]},
235
+ "TRANSSENSORY": {"name": "Trans-cortical Sensory Aphasia", "description":
236
+ "Fluent speech with good repetition but impaired comprehension, similar to Wernicke's but with preserved repetition.",
237
+ "features": ["Fluent speech", "Good repetition", "Poor comprehension", "Semantic errors"]}
238
+ }
239
+
240
+ self.load_configuration()
241
+ self.load_model()
242
+ print(f"推理系統初始化完成,使用設備: {self.device}")
243
+
244
+ def load_configuration(self):
245
+ cfg_path = os.path.join(self.model_dir, "config.json")
246
+ if os.path.exists(cfg_path):
247
+ with open(cfg_path, "r", encoding="utf-8") as f:
248
+ cfg = json.load(f)
249
+ self.aphasia_types_mapping = cfg.get("aphasia_types_mapping", {
250
+ "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
251
+ "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
252
+ "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
253
+ })
254
+ self.num_labels = cfg.get("num_labels", 9)
255
+ self.model_name = cfg.get("model_name", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
256
+ else:
257
+ self.aphasia_types_mapping = {
258
+ "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
259
+ "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
260
+ "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
261
+ }
262
+ self.num_labels = 9
263
+ self.model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
264
+ self.id_to_aphasia_type = {v: k for k, v in self.aphasia_types_mapping.items()}
265
+
266
+ def load_model(self):
267
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, use_fast=True)
268
+ # pad token fix
269
+ if self.tokenizer.pad_token is None:
270
+ if self.tokenizer.eos_token is not None:
271
+ self.tokenizer.pad_token = self.tokenizer.eos_token
272
+ elif self.tokenizer.unk_token is not None:
273
+ self.tokenizer.pad_token = self.tokenizer.unk_token
274
+ else:
275
+ self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
276
+ # optional added tokens
277
+ add_path = os.path.join(self.model_dir, "added_tokens.json")
278
+ if os.path.exists(add_path):
279
+ with open(add_path, "r", encoding="utf-8") as f:
280
+ data = json.load(f)
281
+ tokens = list(data.keys()) if isinstance(data, dict) else data
282
+ if tokens:
283
+ self.tokenizer.add_tokens(tokens)
284
+
285
+ self.config = ModelConfig()
286
+ self.config.model_name = self.model_name
287
+
288
+ self.model = StableAphasiaClassifier(self.config, self.num_labels)
289
+ self.model.bert.resize_token_embeddings(len(self.tokenizer))
290
+
291
+ model_path = os.path.join(self.model_dir, "pytorch_model.bin")
292
+ if not os.path.exists(model_path):
293
+ raise FileNotFoundError(f"模型權重文件不存在: {model_path}")
294
+ state = torch.load(model_path, map_location=self.device)
295
+ self.model.load_state_dict(state) # (once)
296
+
297
+ self.model.to(self.device)
298
+ self.model.eval()
299
+
300
+ # ---------- helpers ----------
301
+
302
+ def _dur_to_float(self, d) -> float:
303
+ """Robustly parse duration from various shapes:
304
+ - number
305
+ - ["word", ms]
306
+ - [start, end]
307
+ - {"dur": ms} (future-proof)
308
+ """
309
+ if isinstance(d, (int, float)):
310
+ return float(d)
311
+ if isinstance(d, list):
312
+ if len(d) == 2:
313
+ # ["word", 300] or [start, end]
314
+ a, b = d[0], d[1]
315
+ # case 1: word + ms
316
+ if isinstance(a, str) and isinstance(b, (int, float)):
317
+ return float(b)
318
+ # case 2: start, end
319
+ if isinstance(a, (int, float)) and isinstance(b, (int, float)):
320
+ return float(b) - float(a)
321
+ if isinstance(d, dict):
322
+ for k in ("dur", "duration", "ms"):
323
+ if k in d and isinstance(d[k], (int, float)):
324
+ return float(d[k])
325
+ return 0.0
326
+
327
+ def _extract_prosodic_features(self, durations, tokens):
328
+ vals = []
329
+ for d in durations:
330
+ vals.append(self._dur_to_float(d))
331
+ vals = [v for v in vals if v > 0]
332
+ if not vals:
333
+ return [0.0] * self.config.prosody_dim
334
+ features = [
335
+ float(np.mean(vals)),
336
+ float(np.std(vals)),
337
+ float(np.median(vals)),
338
+ float(len([v for v in vals if v > (np.mean(vals) * 1.5)])),
339
+ ]
340
+ while len(features) < self.config.prosody_dim:
341
+ features.append(0.0)
342
+ return features[:self.config.prosody_dim]
343
+
344
+ def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
345
+ # map subtoken -> original token index
346
+ subtoken_to_token = []
347
+ for idx, tok in enumerate(tokens):
348
+ subtoks = self.tokenizer.tokenize(tok)
349
+ subtoken_to_token.extend([idx] * max(1, len(subtoks)))
350
+
351
+ aligned_pos = [0] # [CLS]
352
+ aligned_grammar = [[0, 0, 0]] # [CLS]
353
+ aligned_durations = [0.0] # [CLS]
354
+
355
+ # reserve last slot for [SEP]
356
+ max_body = self.config.max_length - 2
357
+ for st_idx in range(max_body):
358
+ if st_idx < len(subtoken_to_token):
359
+ orig = subtoken_to_token[st_idx]
360
+ aligned_pos.append(pos_ids[orig] if orig < len(pos_ids) else 0)
361
+ aligned_grammar.append(grammar_ids[orig] if orig < len(grammar_ids) else [0, 0, 0])
362
+ aligned_durations.append(self._dur_to_float(durations[orig]) if orig < len(durations) else 0.0)
363
+ else:
364
+ aligned_pos.append(0)
365
+ aligned_grammar.append([0, 0, 0])
366
+ aligned_durations.append(0.0)
367
+
368
+ aligned_pos.append(0) # [SEP]
369
+ aligned_grammar.append([0, 0, 0]) # [SEP]
370
+ aligned_durations.append(0.0) # [SEP]
371
+ return aligned_pos, aligned_grammar, aligned_durations
372
+
373
+ def preprocess_sentence(self, sentence_data: dict) -> Optional[dict]:
374
+ all_tokens, all_pos, all_grammar, all_durations = [], [], [], []
375
+ for d_idx, dialogue in enumerate(sentence_data.get("dialogues", [])):
376
+ if d_idx > 0:
377
+ all_tokens.append("[DIALOGUE]")
378
+ all_pos.append(0)
379
+ all_grammar.append([0, 0, 0])
380
+ all_durations.append(0.0)
381
+ for par in dialogue.get("PAR", []):
382
+ if "tokens" in par and par["tokens"]:
383
+ toks = par["tokens"]
384
+ pos_ids = par.get("word_pos_ids", [0] * len(toks))
385
+ gra_ids = par.get("word_grammar_ids", [[0, 0, 0]] * len(toks))
386
+ durs = par.get("word_durations", [0.0] * len(toks))
387
+ all_tokens.extend(toks)
388
+ all_pos.extend(pos_ids)
389
+ all_grammar.extend(gra_ids)
390
+ all_durations.extend(durs)
391
+ if not all_tokens:
392
+ return None
393
+
394
+ text = " ".join(all_tokens)
395
+ enc = self.tokenizer(text, max_length=self.config.max_length, padding="max_length",
396
+ truncation=True, return_tensors="pt")
397
+ aligned_pos, aligned_gra, aligned_dur = self._align_features(
398
+ all_tokens, all_pos, all_grammar, all_durations, enc
399
+ )
400
+ prosody = self._extract_prosodic_features(all_durations, all_tokens)
401
+ prosody_tensor = torch.tensor(prosody).unsqueeze(0).repeat(self.config.max_length, 1)
402
+
403
+ return {
404
+ "input_ids": enc["input_ids"].squeeze(0),
405
+ "attention_mask": enc["attention_mask"].squeeze(0),
406
+ "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
407
+ "word_grammar_ids": torch.tensor(aligned_gra, dtype=torch.long),
408
+ "word_durations": torch.tensor(aligned_dur, dtype=torch.float),
409
+ "prosody_features": prosody_tensor.float(),
410
+ "sentence_id": sentence_data.get("sentence_id", "unknown"),
411
+ "original_tokens": all_tokens,
412
+ "text": text
413
+ }
414
+
415
+ def predict_single(self, sentence_data: dict) -> dict:
416
+ proc = self.preprocess_sentence(sentence_data)
417
+ if proc is None:
418
+ return {"error": "無法處理輸入數據", "sentence_id": sentence_data.get("sentence_id", "unknown")}
419
+ inp = {
420
+ "input_ids": proc["input_ids"].unsqueeze(0).to(self.device),
421
+ "attention_mask": proc["attention_mask"].unsqueeze(0).to(self.device),
422
+ "word_pos_ids": proc["word_pos_ids"].unsqueeze(0).to(self.device),
423
+ "word_grammar_ids": proc["word_grammar_ids"].unsqueeze(0).to(self.device),
424
+ "word_durations": proc["word_durations"].unsqueeze(0).to(self.device),
425
+ "prosody_features": proc["prosody_features"].unsqueeze(0).to(self.device),
426
+ }
427
+ with torch.no_grad():
428
+ out = self.model(**inp)
429
+ logits = out["logits"]
430
+ probs = F.softmax(logits, dim=1).cpu().numpy()[0]
431
+ pred_id = int(np.argmax(probs))
432
+ sev = out["severity_pred"].cpu().numpy()[0]
433
+ flu = float(out["fluency_pred"].cpu().numpy()[0][0])
434
+
435
+ pred_type = self.id_to_aphasia_type[pred_id]
436
+ conf = float(probs[pred_id])
437
+
438
+ dist = {}
439
+ for a_type, t_id in self.aphasia_types_mapping.items():
440
+ dist[a_type] = {"probability": float(probs[t_id]), "percentage": f"{probs[t_id]*100:.2f}%"}
441
+
442
+ sorted_dist = dict(sorted(dist.items(), key=lambda x: x[1]["probability"], reverse=True))
443
+ return {
444
+ "sentence_id": proc["sentence_id"],
445
+ "input_text": proc["text"],
446
+ "original_tokens": proc["original_tokens"],
447
+ "prediction": {
448
+ "predicted_class": pred_type,
449
+ "confidence": conf,
450
+ "confidence_percentage": f"{conf*100:.2f}%"
451
+ },
452
+ "class_description": self.aphasia_descriptions.get(pred_type, {
453
+ "name": pred_type, "description": "Description not available", "features": []
454
+ }),
455
+ "probability_distribution": sorted_dist,
456
+ "additional_predictions": {
457
+ "severity_distribution": {
458
+ "level_0": float(sev[0]), "level_1": float(sev[1]),
459
+ "level_2": float(sev[2]), "level_3": float(sev[3])
460
+ },
461
+ "predicted_severity_level": int(np.argmax(sev)),
462
+ "fluency_score": flu,
463
+ "fluency_rating": "High" if flu > 0.7 else ("Medium" if flu > 0.4 else "Low"),
464
+ }
465
+ }
466
+
467
+ def predict_batch(self, input_file: str, output_file: Optional[str] = None) -> Dict:
468
+ with open(input_file, "r", encoding="utf-8") as f:
469
+ data = json.load(f)
470
+ sentences = data.get("sentences", [])
471
+ results = []
472
+ print(f"開始處理 {len(sentences)} 個句子...")
473
+ for i, s in enumerate(sentences):
474
+ print(f"處理第 {i+1}/{len(sentences)} 個句子...")
475
+ results.append(self.predict_single(s))
476
+ summary = self._generate_summary(results)
477
+ final = {"summary": summary, "total_sentences": len(results), "predictions": results}
478
+ if output_file:
479
+ with open(output_file, "w", encoding="utf-8") as f:
480
+ json.dump(final, f, ensure_ascii=False, indent=2)
481
+ print(f"結果已保存到: {output_file}")
482
+ return final
483
+
484
+ def _generate_summary(self, results: List[dict]) -> dict:
485
+ if not results:
486
+ return {}
487
+ class_counts = defaultdict(int)
488
+ confs, flus = [], []
489
+ sev_counts = defaultdict(int)
490
+ for r in results:
491
+ if "error" in r:
492
+ continue
493
+ c = r["prediction"]["predicted_class"]
494
+ class_counts[c] += 1
495
+ confs.append(r["prediction"]["confidence"])
496
+ flus.append(r["additional_predictions"]["fluency_score"])
497
+ sev_counts[r["additional_predictions"]["predicted_severity_level"]] += 1
498
+ avg_conf = float(np.mean(confs)) if confs else 0.0
499
+ avg_flu = float(np.mean(flus)) if flus else 0.0
500
+ return {
501
+ "classification_distribution": dict(class_counts),
502
+ "classification_percentages": {k: f"{v/len(results)*100:.1f}%" for k, v in class_counts.items()},
503
+ "average_confidence": f"{avg_conf:.3f}",
504
+ "average_fluency_score": f"{avg_flu:.3f}",
505
+ "severity_distribution": dict(sev_counts),
506
+ "confidence_statistics": {} if not confs else {
507
+ "mean": f"{np.mean(confs):.3f}",
508
+ "std": f"{np.std(confs):.3f}",
509
+ "min": f"{np.min(confs):.3f}",
510
+ "max": f"{np.max(confs):.3f}",
511
+ },
512
+ "most_common_prediction": max(class_counts.items(), key=lambda x: x[1])[0] if class_counts else "None",
513
+ }
514
+
515
+ def generate_detailed_report(self, results: List[dict], output_dir: str = "./inference_results"):
516
+ os.makedirs(output_dir, exist_ok=True)
517
+ rows = []
518
+ for r in results:
519
+ if "error" in r:
520
+ continue
521
+ row = {
522
+ "sentence_id": r["sentence_id"],
523
+ "predicted_class": r["prediction"]["predicted_class"],
524
+ "confidence": r["prediction"]["confidence"],
525
+ "class_name": r["class_description"]["name"],
526
+ "severity_level": r["additional_predictions"]["predicted_severity_level"],
527
+ "fluency_score": r["additional_predictions"]["fluency_score"],
528
+ "fluency_rating": r["additional_predictions"]["fluency_rating"],
529
+ "input_text": r["input_text"],
530
+ }
531
+ for a_type, info in r["probability_distribution"].items():
532
+ row[f"prob_{a_type}"] = info["probability"]
533
+ rows.append(row)
534
+ if not rows:
535
+ return None
536
+ df = pd.DataFrame(rows)
537
+ df.to_csv(os.path.join(output_dir, "detailed_predictions.csv"), index=False, encoding="utf-8")
538
+ summary_stats = {
539
+ "total_predictions": int(len(rows)),
540
+ "class_distribution": df["predicted_class"].value_counts().to_dict(),
541
+ "average_confidence": float(df["confidence"].mean()),
542
+ "confidence_std": float(df["confidence"].std()),
543
+ "average_fluency": float(df["fluency_score"].mean()),
544
+ "fluency_std": float(df["fluency_score"].std()),
545
+ "severity_distribution": df["severity_level"].value_counts().to_dict(),
546
+ }
547
+ with open(os.path.join(output_dir, "summary_statistics.json"), "w", encoding="utf-8") as f:
548
+ json.dump(summary_stats, f, ensure_ascii=False, indent=2)
549
+ print(f"詳細報告已生成並保存到: {output_dir}")
550
+ return df
551
+
552
+
553
+ # =========================
554
+ # Convenience: run directly or from pipeline
555
+ # =========================
556
+
557
+ def predict_from_chajson(model_dir: str, chajson_path: str, output_file: Optional[str] = None) -> Dict:
558
+ """
559
+ Convenience entry:
560
+ - Accepts the JSON produced by cha_json.py
561
+ - If it contains 'sentences', runs per-sentence like before
562
+ - If it only contains 'text_all', creates a single pseudo-sentence
563
+ """
564
+ with open(chajson_path, "r", encoding="utf-8") as f:
565
+ data = json.load(f)
566
+
567
+ inf = AphasiaInferenceSystem(model_dir)
568
+
569
+ # If there are sentences, use the full path
570
+ if data.get("sentences"):
571
+ return inf.predict_batch(chajson_path, output_file=output_file)
572
+
573
+ # Else, fall back to a single synthetic sentence using text_all
574
+ text_all = data.get("text_all", "")
575
+ fake = {
576
+ "sentences": [{
577
+ "sentence_id": "S1",
578
+ "dialogues": [{
579
+ "INV": [],
580
+ "PAR": [{"tokens": text_all.split(),
581
+ "word_pos_ids": [0]*len(text_all.split()),
582
+ "word_grammar_ids": [[0,0,0]]*len(text_all.split()),
583
+ "word_durations": [0.0]*len(text_all.split())}]
584
+ }]
585
+ }]
586
+ }
587
+ tmp_path = chajson_path + "._synthetic.json"
588
+ with open(tmp_path, "w", encoding="utf-8") as f:
589
+ json.dump(fake, f, ensure_ascii=False, indent=2)
590
+ out = inf.predict_batch(tmp_path, output_file=output_file)
591
+ try:
592
+ os.remove(tmp_path)
593
+ except Exception:
594
+ pass
595
+ return out
596
+
597
+
598
+ # ---------- CLI ----------
599
+
600
+ def main():
601
+ import argparse
602
+ p = argparse.ArgumentParser(description="失語症分類推理系統")
603
+ p.add_argument("--model_dir", type=str, required=False, default="./adaptive_aphasia_model",
604
+ help="訓練好的模型目錄路徑")
605
+ p.add_argument("--input_file", type=str, required=True,
606
+ help="輸入JSON文件(cha_json 的輸出)")
607
+ p.add_argument("--output_file", type=str, default="./aphasia_predictions.json",
608
+ help="輸出JSON文件路徑")
609
+ p.add_argument("--report_dir", type=str, default="./inference_results",
610
+ help="詳細報告輸出目錄")
611
+ p.add_argument("--generate_report", action="store_true",
612
+ help="是否生成詳細的CSV報告")
613
+ args = p.parse_args()
614
+
615
+ try:
616
+ print("正在初始化推理系統...")
617
+ sys = AphasiaInferenceSystem(args.model_dir)
618
+
619
+ print("開始執行批次預測...")
620
+ results = sys.predict_batch(args.input_file, args.output_file)
621
+
622
+ if args.generate_report:
623
+ print("生成詳細報告...")
624
+ sys.generate_detailed_report(results["predictions"], args.report_dir)
625
+
626
+ print("\n=== 預測摘要 ===")
627
+ s = results["summary"]
628
+ print(f"總句子數: {results['total_sentences']}")
629
+ print(f"平均信心度: {s.get('average_confidence', 'N/A')}")
630
+ print(f"平均流利度: {s.get('average_fluency_score', 'N/A')}")
631
+ print(f"最常見預測: {s.get('most_common_prediction', 'N/A')}")
632
+ print("\n類別分佈:")
633
+ for name, count in s.get("classification_distribution", {}).items():
634
+ pct = s.get("classification_percentages", {}).get(name, "0%")
635
+ print(f" {name}: {count} ({pct})")
636
+ print(f"\n結果已保存到: {args.output_file}")
637
+ except Exception as e:
638
+ print(f"錯誤: {str(e)}")
639
+ import traceback; traceback.print_exc()
640
+
641
+ if __name__ == "__main__":
642
+ main()