Ellie5757575757 committed on
Commit
f84d4de
·
verified ·
1 Parent(s): dcfcfa9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +410 -367
app.py CHANGED
@@ -1,437 +1,480 @@
1
  import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
  import json
6
- import numpy as np
7
- import math
8
- from transformers import AutoTokenizer, AutoModel
9
- from typing import Dict, List, Optional
10
  import logging
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Set up logging
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
- # Recreate the model classes (simplified versions)
17
- class StablePositionalEncoding(nn.Module):
18
- def __init__(self, d_model: int, max_len: int = 5000):
19
- super().__init__()
20
- self.d_model = d_model
21
-
22
- pe = torch.zeros(max_len, d_model)
23
- position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
24
- div_term = torch.exp(torch.arange(0, d_model, 2).float() *
25
- (-math.log(10000.0) / d_model))
26
-
27
- pe[:, 0::2] = torch.sin(position * div_term)
28
- pe[:, 1::2] = torch.cos(position * div_term)
29
-
30
- self.register_buffer('pe', pe.unsqueeze(0))
31
- self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)
32
-
33
- def forward(self, x):
34
- seq_len = x.size(1)
35
- sinusoidal = self.pe[:, :seq_len, :].to(x.device)
36
- learnable = self.learnable_pe[:seq_len, :].unsqueeze(0).expand(x.size(0), -1, -1)
37
- return x + 0.1 * (sinusoidal + learnable)
38
 
39
- class StableMultiHeadAttention(nn.Module):
40
- def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
41
- super().__init__()
42
- self.num_heads = num_heads
43
- self.feature_dim = feature_dim
44
- self.head_dim = feature_dim // num_heads
45
-
46
- assert feature_dim % num_heads == 0
47
-
48
- self.query = nn.Linear(feature_dim, feature_dim)
49
- self.key = nn.Linear(feature_dim, feature_dim)
50
- self.value = nn.Linear(feature_dim, feature_dim)
51
- self.dropout = nn.Dropout(dropout)
52
- self.output_proj = nn.Linear(feature_dim, feature_dim)
53
- self.layer_norm = nn.LayerNorm(feature_dim)
54
 
55
- def forward(self, x, mask=None):
56
- batch_size, seq_len, _ = x.size()
 
 
57
 
58
- Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
59
- K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
60
- V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
 
61
 
62
- scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
 
 
 
63
 
64
- if mask is not None:
65
- if mask.dim() == 2:
66
- mask = mask.unsqueeze(1).unsqueeze(1)
67
- scores.masked_fill_(mask == 0, -1e9)
68
 
69
- attn_weights = F.softmax(scores, dim=-1)
70
- attn_weights = self.dropout(attn_weights)
 
 
 
 
 
71
 
72
- context = torch.matmul(attn_weights, V)
73
- context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.feature_dim)
 
 
 
74
 
75
- output = self.output_proj(context)
76
- return self.layer_norm(output + x)
 
 
 
 
 
 
77
 
78
- class StableLinguisticFeatureExtractor(nn.Module):
79
- def __init__(self, config):
80
- super().__init__()
81
- self.config = config
82
-
83
- # Simplified version - just return zeros if features not available
84
- self.pos_embedding = nn.Embedding(config.get('pos_vocab_size', 150), config.get('pos_emb_dim', 64), padding_idx=0)
85
- self.pos_attention = StableMultiHeadAttention(config.get('pos_emb_dim', 64), num_heads=4)
86
-
87
- self.grammar_projection = nn.Sequential(
88
- nn.Linear(config.get('grammar_dim', 3), config.get('grammar_hidden_dim', 64)),
89
- nn.Tanh(),
90
- nn.LayerNorm(config.get('grammar_hidden_dim', 64)),
91
- nn.Dropout(config.get('dropout_rate', 0.3) * 0.3)
92
- )
93
-
94
- self.duration_projection = nn.Sequential(
95
- nn.Linear(1, config.get('duration_hidden_dim', 128)),
96
- nn.Tanh(),
97
- nn.LayerNorm(config.get('duration_hidden_dim', 128))
98
- )
99
-
100
- self.prosody_projection = nn.Sequential(
101
- nn.Linear(config.get('prosody_dim', 32), config.get('prosody_dim', 32)),
102
- nn.ReLU(),
103
- nn.LayerNorm(config.get('prosody_dim', 32))
104
- )
105
-
106
- total_feature_dim = (config.get('pos_emb_dim', 64) + config.get('grammar_hidden_dim', 64) +
107
- config.get('duration_hidden_dim', 128) + config.get('prosody_dim', 32))
108
- self.feature_fusion = nn.Sequential(
109
- nn.Linear(total_feature_dim, total_feature_dim // 2),
110
- nn.Tanh(),
111
- nn.LayerNorm(total_feature_dim // 2),
112
- nn.Dropout(config.get('dropout_rate', 0.3))
113
- )
114
-
115
- def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
116
- batch_size, seq_len = pos_ids.size()
117
-
118
- # Simple processing - can be expanded later
119
- pos_ids_clamped = pos_ids.clamp(0, self.config.get('pos_vocab_size', 150) - 1)
120
- pos_embeds = self.pos_embedding(pos_ids_clamped)
121
- pos_features = self.pos_attention(pos_embeds, attention_mask)
122
-
123
- grammar_features = self.grammar_projection(grammar_ids.float())
124
- duration_features = self.duration_projection(durations.unsqueeze(-1).float())
125
- prosody_features = self.prosody_projection(prosody_features.float())
126
 
127
- combined_features = torch.cat([
128
- pos_features, grammar_features, duration_features, prosody_features
129
- ], dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- fused_features = self.feature_fusion(combined_features)
 
132
 
133
- mask_expanded = attention_mask.unsqueeze(-1).float()
134
- pooled_features = torch.sum(fused_features * mask_expanded, dim=1) / torch.sum(mask_expanded, dim=1)
 
 
 
 
 
 
135
 
136
- return pooled_features
137
-
138
- class StableAphasiaClassifier(nn.Module):
139
- def __init__(self, config, num_labels: int):
140
- super().__init__()
141
- self.config = config
142
- self.num_labels = num_labels
143
 
144
- try:
145
- # Load the base BERT model
146
- self.bert = AutoModel.from_pretrained(config.get('model_name', 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'))
147
- self.bert_config = self.bert.config
148
 
149
- # Freeze embeddings for stability
150
- for param in self.bert.embeddings.parameters():
151
- param.requires_grad = False
 
 
 
 
 
152
 
153
- self.positional_encoder = StablePositionalEncoding(
154
- d_model=self.bert_config.hidden_size,
155
- max_len=config.get('max_length', 512)
156
- )
157
 
158
- self.linguistic_extractor = StableLinguisticFeatureExtractor(config)
 
 
159
 
160
- bert_dim = self.bert_config.hidden_size
161
- linguistic_dim = (config.get('pos_emb_dim', 64) + config.get('grammar_hidden_dim', 64) +
162
- config.get('duration_hidden_dim', 128) + config.get('prosody_dim', 32)) // 2
 
 
163
 
164
- self.feature_fusion = nn.Sequential(
165
- nn.Linear(bert_dim + linguistic_dim, bert_dim),
166
- nn.LayerNorm(bert_dim),
167
- nn.Tanh(),
168
- nn.Dropout(config.get('dropout_rate', 0.3))
169
- )
170
 
171
- # Classifier
172
- self.classifier = self._build_classifier(bert_dim, num_labels, config)
 
 
 
173
 
174
- # Multi-task heads (simplified)
175
- self.severity_head = nn.Sequential(
176
- nn.Linear(bert_dim, 4),
177
- nn.Softmax(dim=-1)
178
- )
179
 
180
- self.fluency_head = nn.Sequential(
181
- nn.Linear(bert_dim, 1),
182
- nn.Sigmoid()
183
- )
184
 
185
- except Exception as e:
186
- logger.error(f"Error initializing model: {e}")
187
- raise
188
-
189
- def _build_classifier(self, input_dim: int, num_labels: int, config):
190
- layers = []
191
- current_dim = input_dim
192
-
193
- classifier_hidden_dims = config.get('classifier_hidden_dims', [512, 256])
194
- for hidden_dim in classifier_hidden_dims:
195
- layers.extend([
196
- nn.Linear(current_dim, hidden_dim),
197
- nn.LayerNorm(hidden_dim),
198
- nn.Tanh(),
199
- nn.Dropout(config.get('dropout_rate', 0.3))
200
- ])
201
- current_dim = hidden_dim
202
-
203
- layers.append(nn.Linear(current_dim, num_labels))
204
- return nn.Sequential(*layers)
205
-
206
- def forward(self, input_ids, attention_mask, labels=None,
207
- word_pos_ids=None, word_grammar_ids=None, word_durations=None,
208
- prosody_features=None, **kwargs):
209
-
210
- bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
211
- sequence_output = bert_outputs.last_hidden_state
212
-
213
- position_enhanced = self.positional_encoder(sequence_output)
214
- pooled_output = self._attention_pooling(position_enhanced, attention_mask)
215
-
216
- # Handle missing linguistic features
217
- if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
218
- if prosody_features is None:
219
- batch_size, seq_len = input_ids.size()
220
- prosody_features = torch.zeros(
221
- batch_size, seq_len, self.config.get('prosody_dim', 32),
222
- device=input_ids.device
223
- )
224
 
225
- linguistic_features = self.linguistic_extractor(
226
- word_pos_ids, word_grammar_ids, word_durations,
227
- prosody_features, attention_mask
 
 
 
228
  )
 
229
  else:
230
- # Create dummy linguistic features
231
- linguistic_features = torch.zeros(
232
- input_ids.size(0),
233
- (self.config.get('pos_emb_dim', 64) + self.config.get('grammar_hidden_dim', 64) +
234
- self.config.get('duration_hidden_dim', 128) + self.config.get('prosody_dim', 32)) // 2,
235
- device=input_ids.device
236
  )
237
-
238
- combined_features = torch.cat([pooled_output, linguistic_features], dim=1)
239
- fused_features = self.feature_fusion(combined_features)
240
-
241
- logits = self.classifier(fused_features)
242
- severity_pred = self.severity_head(fused_features)
243
- fluency_pred = self.fluency_head(fused_features)
244
-
245
- return {
246
- "logits": logits,
247
- "severity_pred": severity_pred,
248
- "fluency_pred": fluency_pred,
249
- }
250
-
251
- def _attention_pooling(self, sequence_output, attention_mask):
252
- attention_weights = torch.softmax(
253
- torch.sum(sequence_output, dim=-1, keepdim=True), dim=1
254
  )
255
- attention_weights = attention_weights * attention_mask.unsqueeze(-1).float()
256
- attention_weights = attention_weights / (torch.sum(attention_weights, dim=1, keepdim=True) + 1e-9)
257
- pooled = torch.sum(sequence_output * attention_weights, dim=1)
258
- return pooled
259
 
260
- # Load configuration and model
261
- def load_model():
 
 
262
  try:
263
- # Load configuration
264
- with open("config.json", "r") as f:
265
- config = json.load(f)
266
-
267
- logger.info(f"Loaded config: {config}")
268
-
269
- # Initialize tokenizer
270
- model_name = config.get('model_name', 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')
271
- tokenizer = AutoTokenizer.from_pretrained(model_name)
272
 
273
- # Add special tokens
274
- special_tokens = ["[DIALOGUE]", "[TURN]", "[PAUSE]", "[REPEAT]", "[HESITATION]"]
275
- tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
- # Initialize model
278
- num_labels = config.get('num_labels', 9)
279
- model_config = config.get('model_config', {})
 
280
 
281
- model = StableAphasiaClassifier(model_config, num_labels)
282
- model.bert.resize_token_embeddings(len(tokenizer))
283
 
284
- # Load model weights
285
  try:
286
- state_dict = torch.load("pytorch_model.bin", map_location="cpu")
287
- model.load_state_dict(state_dict)
288
- logger.info("Successfully loaded model weights")
289
- except Exception as e:
290
- logger.error(f"Error loading model weights: {e}")
291
- logger.info("Using randomly initialized weights")
292
-
293
- model.eval()
294
-
295
- # Get label mapping
296
- id2label = config.get('id2label', {})
297
-
298
- return model, tokenizer, id2label
299
-
300
- except Exception as e:
301
- logger.error(f"Error loading model: {e}")
302
- raise
303
-
304
- # Initialize model (with error handling)
305
- try:
306
- model, tokenizer, id2label = load_model()
307
- logger.info("Model loaded successfully!")
308
- except Exception as e:
309
- logger.error(f"Failed to load model: {e}")
310
- model, tokenizer, id2label = None, None, {}
311
-
312
- def predict_aphasia(text):
313
- """Predict aphasia type from text"""
314
- try:
315
- if model is None or tokenizer is None:
316
- return "Error: Model not loaded properly. Please check the logs.", 0.0, "N/A", 0.0
317
-
318
- if not text or not text.strip():
319
- return "Please enter some text for analysis.", 0.0, "N/A", 0.0
320
-
321
- # Tokenize input
322
- inputs = tokenizer(
323
- text,
324
- max_length=512,
325
- padding="max_length",
326
- truncation=True,
327
- return_tensors="pt"
328
- )
329
-
330
- # Create dummy linguistic features (since we don't have them from raw text)
331
- batch_size, seq_len = inputs["input_ids"].size()
332
- dummy_pos = torch.zeros(batch_size, seq_len, dtype=torch.long)
333
- dummy_grammar = torch.zeros(batch_size, seq_len, 3, dtype=torch.long)
334
- dummy_durations = torch.zeros(batch_size, seq_len, dtype=torch.float)
335
- dummy_prosody = torch.zeros(batch_size, seq_len, 32, dtype=torch.float)
336
 
337
- # Make prediction
338
- with torch.no_grad():
339
- outputs = model(
340
- input_ids=inputs["input_ids"],
341
- attention_mask=inputs["attention_mask"],
342
- word_pos_ids=dummy_pos,
343
- word_grammar_ids=dummy_grammar,
344
- word_durations=dummy_durations,
345
- prosody_features=dummy_prosody
 
 
 
 
346
  )
347
-
348
- # Process outputs
349
- logits = outputs["logits"]
350
- probs = F.softmax(logits, dim=1)
351
- predicted_class_id = torch.argmax(probs, dim=1).item()
352
- confidence = torch.max(probs, dim=1)[0].item()
353
-
354
- # Get predicted label
355
- predicted_label = id2label.get(str(predicted_class_id), f"Class_{predicted_class_id}")
356
-
357
- # Get additional predictions
358
- severity = torch.argmax(outputs["severity_pred"], dim=1).item()
359
- fluency = outputs["fluency_pred"].item()
360
-
361
- # Format results
362
- result = f"Predicted Aphasia Type: {predicted_label}"
363
- confidence_str = f"Confidence: {confidence:.2%}"
364
- severity_str = f"Severity Level: {severity}/3"
365
- fluency_str = f"Fluency Score: {fluency:.3f}"
366
-
367
- return result, confidence, severity_str, fluency_str
368
-
369
  except Exception as e:
370
- logger.error(f"Prediction error: {e}")
371
- return f"Error during prediction: {str(e)}", 0.0, "N/A", 0.0
 
 
 
 
 
 
372
 
373
  # Create Gradio interface
374
  def create_interface():
375
- """Create Gradio interface"""
376
 
377
- with gr.Blocks(title="Aphasia Classification System") as demo:
378
- gr.Markdown("# 🧠 Advanced Aphasia Classification System")
379
- gr.Markdown("Enter speech or text data to classify aphasia type and analyze linguistic patterns.")
380
-
381
- with gr.Row():
382
- with gr.Column():
383
- text_input = gr.Textbox(
384
- label="Input Text",
385
- placeholder="Enter speech transcription or text for analysis...",
386
- lines=5
387
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
 
389
- submit_btn = gr.Button("Analyze Text", variant="primary")
390
- clear_btn = gr.Button("Clear", variant="secondary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
- with gr.Column():
393
- prediction_output = gr.Textbox(label="Prediction Result", lines=2)
394
- confidence_output = gr.Textbox(label="Confidence Score", lines=1)
395
- severity_output = gr.Textbox(label="Severity Analysis", lines=1)
396
- fluency_output = gr.Textbox(label="Fluency Analysis", lines=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
  # Event handlers
399
- submit_btn.click(
400
- fn=predict_aphasia,
401
- inputs=[text_input],
402
- outputs=[prediction_output, confidence_output, severity_output, fluency_output]
 
 
 
 
 
 
403
  )
404
 
405
- clear_btn.click(
406
- fn=lambda: ("", "", "", "", ""),
407
- inputs=[],
408
- outputs=[text_input, prediction_output, confidence_output, severity_output, fluency_output]
409
- )
410
-
411
- # Add examples
412
- gr.Examples(
413
- examples=[
414
- ["The patient... uh... wants to... go home but... cannot... find the words"],
415
- ["Woman is... is washing dishes and the... the... sink is overflowing with water everywhere"],
416
- ["Cookie is in the cookie jar on the... on the... what do you call it... the shelf thing"]
417
- ],
418
- inputs=text_input
419
  )
420
 
421
- gr.Markdown("### About")
422
- gr.Markdown("This system uses a specialized transformer model trained on clinical speech data to classify different types of aphasia.")
 
 
 
 
 
423
 
424
  return demo
425
 
426
- # Launch the app
427
  if __name__ == "__main__":
428
  try:
 
 
 
 
 
 
 
 
 
429
  demo = create_interface()
430
  demo.launch(
431
  server_name="0.0.0.0",
432
  server_port=7860,
433
- show_error=True
 
434
  )
 
435
  except Exception as e:
436
  logger.error(f"Failed to launch app: {e}")
437
- print(f"Application startup failed: {e}")
 
 
1
  import gradio as gr
 
 
 
2
  import json
3
+ import os
4
+ import tempfile
 
 
5
  import logging
6
+ import traceback
7
+ from pathlib import Path
8
+
9
+ # Import your pipeline modules
10
+ try:
11
+ from utils_audio import convert_to_wav
12
+ from to_cha import to_cha_from_wav
13
+ from cha_json import cha_to_json_file
14
+ from output import predict_from_chajson
15
+ except ImportError as e:
16
+ logging.error(f"Import error: {e}")
17
+ # Fallback imports or error handling
18
 
19
  # Set up logging
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
23
+ # Configuration
24
+ MODEL_DIR = "./adaptive_aphasia_model" # Path to your trained model
25
+ SUPPORTED_AUDIO_FORMATS = [".mp3", ".mp4", ".wav", ".m4a", ".flac", ".ogg"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
def run_complete_pipeline(audio_file_path: str) -> dict:
    """Run the full pipeline: audio -> WAV -> CHA -> JSON -> model prediction.

    Args:
        audio_file_path: Path to the uploaded audio file (any supported format).

    Returns:
        dict with keys:
            success (bool): whether all four stages completed,
            results (on success): output of ``predict_from_chajson``,
            error (str, on failure): the exception text,
            message (str): human-readable status.
    """
    try:
        logger.info(f"Starting pipeline for: {audio_file_path}")

        # Step 1: Convert to WAV (16 kHz mono — presumably what the
        # downstream aligner expects; confirm against to_cha_from_wav).
        logger.info("Step 1: Converting audio to WAV...")
        wav_path = convert_to_wav(audio_file_path, sr=16000, mono=True)
        logger.info(f"WAV conversion completed: {wav_path}")

        # Step 2: Generate CHA transcript using Batchalign
        logger.info("Step 2: Generating CHA file...")
        cha_path = to_cha_from_wav(wav_path, lang="eng")
        logger.info(f"CHA generation completed: {cha_path}")

        # Step 3: Convert CHA to JSON
        logger.info("Step 3: Converting CHA to JSON...")
        chajson_path, json_data = cha_to_json_file(cha_path)
        logger.info(f"JSON conversion completed: {chajson_path}")

        # Step 4: Run aphasia classification
        logger.info("Step 4: Running aphasia classification...")
        results = predict_from_chajson(MODEL_DIR, chajson_path, output_file=None)
        logger.info("Classification completed")

        # Cleanup of intermediate files. Each unlink gets its own try so a
        # failure on one file (previously a single try around all three)
        # no longer skips removal of the remaining ones.
        for tmp_path in (wav_path, cha_path, chajson_path):
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                logger.warning(f"Cleanup error: {cleanup_error}")

        return {
            "success": True,
            "results": results,
            "message": "Pipeline completed successfully"
        }

    except Exception as e:
        # Top-level boundary: report failure to the UI instead of crashing.
        logger.error(f"Pipeline error: {str(e)}")
        logger.error(traceback.format_exc())
        return {
            "success": False,
            "error": str(e),
            "message": f"Pipeline failed: {str(e)}"
        }
def process_audio_input(audio_file):
    """Run the audio pipeline on an uploaded file and format the results.

    Args:
        audio_file: Gradio upload value — a filesystem path string, or a
            file-like object exposing ``.name`` (older Gradio versions).

    Returns:
        A 5-tuple of display strings:
        (main result, detailed analysis, metrics, probability breakdown,
        summary). On any error the first element carries the error message
        and the remaining slots are empty or explanatory.
    """
    try:
        if audio_file is None:
            return (
                "❌ Error: No audio file uploaded",
                "",
                "",
                "",
                ""
            )

        # Resolve the upload to a plain path. (The original pre-assigned
        # file_path = audio_file before this branch — a dead store, removed.)
        if isinstance(audio_file, str):
            file_path = audio_file
        else:
            # Handle Gradio file object
            file_path = audio_file.name if hasattr(audio_file, 'name') else str(audio_file)

        # Reject formats the conversion step cannot handle.
        file_ext = Path(file_path).suffix.lower()
        if file_ext not in SUPPORTED_AUDIO_FORMATS:
            return (
                f"❌ Error: Unsupported file format {file_ext}",
                f"Supported formats: {', '.join(SUPPORTED_AUDIO_FORMATS)}",
                "",
                "",
                ""
            )

        # Run the complete pipeline (audio -> WAV -> CHA -> JSON -> model).
        pipeline_result = run_complete_pipeline(file_path)

        if not pipeline_result["success"]:
            return (
                f"❌ Pipeline Error: {pipeline_result['message']}",
                pipeline_result.get('error', ''),
                "",
                "",
                ""
            )

        results = pipeline_result["results"]

        # Format the main prediction from the first sentence-level result.
        if "predictions" in results and len(results["predictions"]) > 0:
            first_pred = results["predictions"][0]

            if "error" in first_pred:
                return (
                    f"❌ Classification Error: {first_pred['error']}",
                    "",
                    "",
                    "",
                    ""
                )

            # Main prediction
            predicted_class = first_pred["prediction"]["predicted_class"]
            confidence = first_pred["prediction"]["confidence_percentage"]
            class_description = first_pred["class_description"]["name"]

            main_result = f"🧠 **Predicted Aphasia Type:** {predicted_class}\n"
            main_result += f"πŸ“Š **Confidence:** {confidence}\n"
            main_result += f"πŸ“‹ **Description:** {class_description}"

            # Detailed analysis
            features = first_pred["class_description"].get("features", [])
            detailed_analysis = f"**Key Features:**\n"
            for feature in features:
                detailed_analysis += f"β€’ {feature}\n"

            detailed_analysis += f"\n**Clinical Description:**\n"
            detailed_analysis += first_pred["class_description"].get("description", "No description available")

            # Additional metrics
            additional_info = first_pred["additional_predictions"]
            severity_level = additional_info["predicted_severity_level"]
            fluency_score = additional_info["fluency_score"]
            fluency_rating = additional_info["fluency_rating"]

            additional_metrics = f"**Severity Level:** {severity_level}/3\n"
            additional_metrics += f"**Fluency Score:** {fluency_score:.3f} ({fluency_rating})\n"

            # Probability distribution (top 3). NOTE(review): taking the
            # first three items assumes predict_from_chajson returns the
            # distribution already sorted by probability — confirm there.
            prob_dist = first_pred["probability_distribution"]
            top_3 = list(prob_dist.items())[:3]

            probability_breakdown = "**Top 3 Classifications:**\n"
            for i, (aphasia_type, info) in enumerate(top_3, 1):
                probability_breakdown += f"{i}. {aphasia_type}: {info['percentage']}\n"

            # Summary statistics
            summary = results.get("summary", {})
            summary_text = f"**Processing Summary:**\n"
            summary_text += f"β€’ Total sentences analyzed: {results.get('total_sentences', 'N/A')}\n"
            summary_text += f"β€’ Average confidence: {summary.get('average_confidence', 'N/A')}\n"
            summary_text += f"β€’ Average fluency: {summary.get('average_fluency_score', 'N/A')}\n"

            return (
                main_result,
                detailed_analysis,
                additional_metrics,
                probability_breakdown,
                summary_text
            )

        else:
            return (
                "❌ No predictions generated",
                "The audio file may not contain analyzable speech",
                "",
                "",
                ""
            )

    except Exception as e:
        # UI boundary: surface the error rather than crash the app.
        logger.error(f"Processing error: {str(e)}")
        logger.error(traceback.format_exc())
        return (
            f"❌ Processing Error: {str(e)}",
            "Please check the logs for more details",
            "",
            "",
            ""
        )
def process_text_input(text_input):
    """Classify aphasia type directly from a text transcription (fallback path).

    Builds a minimal CHA-style JSON document with zero-filled linguistic
    features, writes it to a temporary file, and runs the model on it.

    Args:
        text_input: Raw transcription text entered by the user.

    Returns:
        A 5-tuple of display strings (result, description, metrics,
        assessment, status); on error the first element is the error message.
    """
    try:
        if not text_input or not text_input.strip():
            return (
                "❌ Error: Please enter some text for analysis",
                "",
                "",
                "",
                ""
            )

        # Tokenize once instead of re-splitting for every feature list.
        tokens = text_input.split()

        # Create a simple JSON structure for text-only input; per-word
        # features are zero-filled because there is no audio alignment.
        temp_json = {
            "sentences": [{
                "sentence_id": "S1",
                "aphasia_type": "UNKNOWN",
                "dialogues": [{
                    "INV": [],
                    "PAR": [{
                        "tokens": tokens,
                        "word_pos_ids": [0] * len(tokens),
                        "word_grammar_ids": [[0, 0, 0]] * len(tokens),
                        "word_durations": [0.0] * len(tokens),
                        "utterance_text": text_input
                    }]
                }]
            }],
            "text_all": text_input
        }

        # Save to a temporary file for the prediction entry point.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(temp_json, f, ensure_ascii=False, indent=2)
            temp_json_path = f.name

        try:
            # Run prediction
            results = predict_from_chajson(MODEL_DIR, temp_json_path, output_file=None)
        finally:
            # Remove the temp file even if prediction raises; catch only
            # OS-level errors instead of the previous bare `except:`.
            try:
                os.unlink(temp_json_path)
            except OSError:
                pass

        # Format results (similar to audio processing)
        if "predictions" in results and len(results["predictions"]) > 0:
            first_pred = results["predictions"][0]

            predicted_class = first_pred["prediction"]["predicted_class"]
            confidence = first_pred["prediction"]["confidence_percentage"]

            return (
                f"🧠 **Predicted:** {predicted_class} ({confidence})",
                first_pred["class_description"]["description"],
                f"Severity: {first_pred['additional_predictions']['predicted_severity_level']}/3",
                f"Fluency: {first_pred['additional_predictions']['fluency_rating']}",
                "Text-based analysis completed"
            )
        else:
            return (
                "❌ No predictions generated",
                "",
                "",
                "",
                ""
            )

    except Exception as e:
        # UI boundary: report the error instead of crashing the app.
        logger.error(f"Text processing error: {str(e)}")
        return (
            f"❌ Error: {str(e)}",
            "",
            "",
            "",
            ""
        )
# Create Gradio interface
def create_interface():
    """Create the main Gradio interface.

    Builds a two-tab Blocks app: an audio-upload tab wired to
    process_audio_input and a text-entry fallback tab wired to
    process_text_input. Each handler returns five display strings that map
    onto five read-only Textbox components.

    Returns:
        The assembled gr.Blocks instance (not yet launched).
    """

    with gr.Blocks(
        title="Advanced Aphasia Classification System",
        theme=gr.themes.Soft(),
        css="""
        .main-header { text-align: center; margin-bottom: 2rem; }
        .upload-section { border: 2px dashed #ccc; padding: 2rem; border-radius: 10px; }
        .results-section { margin-top: 2rem; }
        """
    ) as demo:

        # Header
        gr.HTML("""
        <div class="main-header">
            <h1>🧠 Advanced Aphasia Classification System</h1>
            <p>Upload audio files (MP3, MP4, WAV) or enter text to analyze speech patterns and classify aphasia types</p>
        </div>
        """)

        with gr.Tabs():
            # Audio Input Tab — primary workflow
            with gr.TabItem("🎡 Audio Analysis", id="audio_tab"):
                gr.Markdown("### Upload Audio File")
                gr.Markdown("Supported formats: MP3, MP4, WAV, M4A, FLAC, OGG")

                with gr.Row():
                    # Left column: upload control and trigger button
                    with gr.Column(scale=1):
                        audio_input = gr.File(
                            label="Upload Audio File",
                            file_types=["audio"],
                            type="filepath"
                        )

                        process_audio_btn = gr.Button(
                            "πŸ” Analyze Audio",
                            variant="primary",
                            size="lg"
                        )

                        gr.Markdown("**Note:** Processing may take 1-3 minutes depending on audio length")

                    # Right column: five result panes filled by the handler
                    with gr.Column(scale=2, visible=True) as audio_results:
                        gr.Markdown("### πŸ“Š Analysis Results")

                        audio_main_result = gr.Textbox(
                            label="🎯 Primary Classification",
                            lines=3,
                            interactive=False
                        )

                        with gr.Row():
                            audio_detailed = gr.Textbox(
                                label="πŸ“‹ Detailed Analysis",
                                lines=6,
                                interactive=False
                            )

                            audio_metrics = gr.Textbox(
                                label="πŸ“ˆ Additional Metrics",
                                lines=6,
                                interactive=False
                            )

                        with gr.Row():
                            audio_probabilities = gr.Textbox(
                                label="πŸ“Š Probability Breakdown",
                                lines=4,
                                interactive=False
                            )

                            audio_summary = gr.Textbox(
                                label="πŸ“ Processing Summary",
                                lines=4,
                                interactive=False
                            )

            # Text Input Tab (Fallback) — no audio pipeline, text only
            with gr.TabItem("πŸ“ Text Analysis", id="text_tab"):
                gr.Markdown("### Direct Text Input")
                gr.Markdown("Enter speech transcription or text for analysis (fallback option)")

                with gr.Row():
                    with gr.Column():
                        text_input = gr.Textbox(
                            label="Input Text",
                            placeholder="Enter speech transcription or text for analysis...",
                            lines=5
                        )

                        process_text_btn = gr.Button(
                            "πŸ” Analyze Text",
                            variant="secondary",
                            size="lg"
                        )

                    # Results section for text
                    with gr.Column() as text_results:
                        gr.Markdown("### πŸ“Š Analysis Results")

                        text_main_result = gr.Textbox(
                            label="🎯 Primary Classification",
                            lines=2,
                            interactive=False
                        )

                        with gr.Row():
                            text_detailed = gr.Textbox(
                                label="πŸ“‹ Clinical Description",
                                lines=4,
                                interactive=False
                            )

                            text_metrics = gr.Textbox(
                                label="πŸ“ˆ Metrics",
                                lines=4,
                                interactive=False
                            )

                        with gr.Row():
                            text_probabilities = gr.Textbox(
                                label="πŸ“Š Assessment",
                                lines=2,
                                interactive=False
                            )

                            text_summary = gr.Textbox(
                                label="πŸ“ Status",
                                lines=2,
                                interactive=False
                            )

        # Event handlers: output order must match the handlers' 5-tuples.
        process_audio_btn.click(
            fn=process_audio_input,
            inputs=[audio_input],
            outputs=[
                audio_main_result,
                audio_detailed,
                audio_metrics,
                audio_probabilities,
                audio_summary
            ]
        )

        process_text_btn.click(
            fn=process_text_input,
            inputs=[text_input],
            outputs=[
                text_main_result,
                text_detailed,
                text_metrics,
                text_probabilities,
                text_summary
            ]
        )

        # Footer
        gr.HTML("""
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; border-top: 1px solid #eee;">
            <p><strong>About:</strong> This system uses advanced NLP and acoustic analysis to classify different types of aphasia from speech samples.</p>
            <p><em>For research and clinical assessment purposes.</em></p>
        </div>
        """)

    return demo
 
457
# Launch the application
if __name__ == "__main__":
    try:
        logger.info("Starting Aphasia Classification System...")

        # Check if model directory exists. Note: this only warns — the app
        # still launches, and prediction calls will fail later if the model
        # is truly missing.
        if not os.path.exists(MODEL_DIR):
            logger.error(f"Model directory not found: {MODEL_DIR}")
            print(f"❌ Error: Model directory not found: {MODEL_DIR}")
            print("Please ensure your trained model is in the correct directory.")

        # Create and launch interface on all interfaces, port 7860
        # (the standard Hugging Face Spaces port).
        demo = create_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            share=False
        )

    except Exception as e:
        # Last-resort handler: log the traceback and echo to stdout.
        logger.error(f"Failed to launch app: {e}")
        logger.error(traceback.format_exc())
        print(f"❌ Application startup failed: {e}")