Ellie5757575757 committed on
Commit
e800564
·
verified ·
1 Parent(s): d73663e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +434 -18
app.py CHANGED
@@ -1,21 +1,437 @@
1
- import warnings; warnings.filterwarnings("ignore", message="pkg_resources is deprecated")
2
  import gradio as gr
3
- from pipeline import run_pipeline
4
- import gradio_client, fastapi
5
- print("VERSIONS -> gradio:", gr.__version__, "gradio_client:", gradio_client.__version__, "fastapi:", fastapi.__version__)
6
-
7
- def infer(file):
8
- path = getattr(file, "name", file) # gr.File gives a tempfile
9
- return run_pipeline(path, out_style="json") # returns JSON string
10
-
11
- demo = gr.Interface(
12
- fn=infer,
13
- inputs=gr.File(label="Upload audio/video (mp3, mp4, wav)"),
14
- outputs=gr.Textbox(label="Result (JSON text)"), # ← was gr.JSON; use Textbox
15
- title="Aphasia Classification",
16
- description="MP3/MP4 → WAV → .cha → JSON → model",
17
- concurrency_limit=1,
18
- )
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  if __name__ == "__main__":
21
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True, max_threads=1, share=True) # ← share=True
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import json
6
+ import numpy as np
7
+ import math
8
+ from transformers import AutoTokenizer, AutoModel
9
+ from typing import Dict, List, Optional
10
+ import logging
 
 
 
 
 
 
 
11
 
12
# Module-wide logging: INFO level, logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
15
+
16
# Model building blocks (simplified, inference-oriented versions).
class StablePositionalEncoding(nn.Module):
    """Sinusoidal positional encoding blended with a small learnable table.

    The forward pass adds ``0.1 * (sinusoidal + learnable)`` to the input,
    keeping the positional signal deliberately weak for stability.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        # Classic transformer sinusoid table, stored as (1, max_len, d_model).
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() *
                          (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        self.register_buffer('pe', table.unsqueeze(0))

        # Small per-position learnable offsets (std 0.01 at init).
        self.learnable_pe = nn.Parameter(0.01 * torch.randn(max_len, d_model))

    def forward(self, x):
        """Add the damped positional signal to ``x`` of shape (B, S, D)."""
        batch, seq_len = x.size(0), x.size(1)
        fixed = self.pe[:, :seq_len, :].to(x.device)
        learned = self.learnable_pe[:seq_len, :].unsqueeze(0).expand(batch, -1, -1)
        return x + 0.1 * (fixed + learned)
38
+
39
class StableMultiHeadAttention(nn.Module):
    """Standard multi-head self-attention with residual + post LayerNorm."""

    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        assert feature_dim % num_heads == 0
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads

        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def _split_heads(self, t, batch, seq):
        # (B, S, D) -> (B, H, S, D/H)
        return t.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Self-attend over ``x`` (B, S, D); ``mask`` marks valid positions with 1."""
        batch, seq, _ = x.size()

        q = self._split_heads(self.query(x), batch, seq)
        k = self._split_heads(self.key(x), batch, seq)
        v = self._split_heads(self.value(x), batch, seq)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            if mask.dim() == 2:
                # (B, S) -> (B, 1, 1, S): padding keys hidden from every head/query.
                mask = mask.unsqueeze(1).unsqueeze(1)
            scores.masked_fill_(mask == 0, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)
        context = context.transpose(1, 2).contiguous().view(batch, seq, self.feature_dim)

        # Residual connection, then LayerNorm (post-norm).
        return self.layer_norm(self.output_proj(context) + x)
77
+
78
class StableLinguisticFeatureExtractor(nn.Module):
    """Fuse POS, grammar, duration, and prosody streams into one pooled vector.

    Output width is ``(pos_emb_dim + grammar_hidden_dim + duration_hidden_dim
    + prosody_dim) // 2`` — 144 with the defaults below.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # POS tags: embedding + self-attention over the tag sequence.
        self.pos_embedding = nn.Embedding(config.get('pos_vocab_size', 150),
                                          config.get('pos_emb_dim', 64), padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.get('pos_emb_dim', 64), num_heads=4)

        self.grammar_projection = nn.Sequential(
            nn.Linear(config.get('grammar_dim', 3), config.get('grammar_hidden_dim', 64)),
            nn.Tanh(),
            nn.LayerNorm(config.get('grammar_hidden_dim', 64)),
            nn.Dropout(config.get('dropout_rate', 0.3) * 0.3)
        )

        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.get('duration_hidden_dim', 128)),
            nn.Tanh(),
            nn.LayerNorm(config.get('duration_hidden_dim', 128))
        )

        self.prosody_projection = nn.Sequential(
            nn.Linear(config.get('prosody_dim', 32), config.get('prosody_dim', 32)),
            nn.ReLU(),
            nn.LayerNorm(config.get('prosody_dim', 32))
        )

        total_feature_dim = (config.get('pos_emb_dim', 64) + config.get('grammar_hidden_dim', 64) +
                             config.get('duration_hidden_dim', 128) + config.get('prosody_dim', 32))
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.get('dropout_rate', 0.3))
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        """Return a (batch, total_dim // 2) masked-mean summary of all streams.

        Args:
            pos_ids: (B, S) long POS-tag ids.
            grammar_ids: (B, S, grammar_dim) grammar indicators (cast to float).
            durations: (B, S) per-token durations.
            prosody_features: (B, S, prosody_dim) prosodic features.
            attention_mask: (B, S) with 1 for real tokens, 0 for padding.
        """
        # Clamp ids so out-of-vocabulary tags cannot index past the embedding table.
        pos_ids_clamped = pos_ids.clamp(0, self.config.get('pos_vocab_size', 150) - 1)
        pos_embeds = self.pos_embedding(pos_ids_clamped)
        pos_features = self.pos_attention(pos_embeds, attention_mask)

        grammar_features = self.grammar_projection(grammar_ids.float())
        duration_features = self.duration_projection(durations.unsqueeze(-1).float())
        prosody_proj = self.prosody_projection(prosody_features.float())

        combined_features = torch.cat(
            [pos_features, grammar_features, duration_features, prosody_proj], dim=-1)
        fused_features = self.feature_fusion(combined_features)

        # Masked mean pooling. Fix: clamp the denominator so an all-padding
        # row yields zeros instead of NaN (original divided by sum(mask)
        # with no guard against zero).
        mask_expanded = attention_mask.unsqueeze(-1).float()
        denom = torch.sum(mask_expanded, dim=1).clamp(min=1e-9)
        pooled_features = torch.sum(fused_features * mask_expanded, dim=1) / denom

        return pooled_features
137
+
138
class StableAphasiaClassifier(nn.Module):
    """BERT encoder + linguistic-feature branch + multi-task heads.

    Heads: main aphasia-type logits, a 4-way severity distribution
    (softmax), and a scalar fluency score in [0, 1] (sigmoid).
    """

    def __init__(self, config, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        try:
            # Backbone transformer (PubMedBERT by default).
            self.bert = AutoModel.from_pretrained(
                config.get('model_name', 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'))
            self.bert_config = self.bert.config

            # Keep the embedding layer frozen for training stability.
            for param in self.bert.embeddings.parameters():
                param.requires_grad = False

            self.positional_encoder = StablePositionalEncoding(
                d_model=self.bert_config.hidden_size,
                max_len=config.get('max_length', 512)
            )

            self.linguistic_extractor = StableLinguisticFeatureExtractor(config)

            bert_dim = self.bert_config.hidden_size
            linguistic_dim = (config.get('pos_emb_dim', 64) + config.get('grammar_hidden_dim', 64) +
                              config.get('duration_hidden_dim', 128) + config.get('prosody_dim', 32)) // 2

            self.feature_fusion = nn.Sequential(
                nn.Linear(bert_dim + linguistic_dim, bert_dim),
                nn.LayerNorm(bert_dim),
                nn.Tanh(),
                nn.Dropout(config.get('dropout_rate', 0.3))
            )

            self.classifier = self._build_classifier(bert_dim, num_labels, config)

            # Auxiliary multi-task heads (simplified).
            self.severity_head = nn.Sequential(nn.Linear(bert_dim, 4), nn.Softmax(dim=-1))
            self.fluency_head = nn.Sequential(nn.Linear(bert_dim, 1), nn.Sigmoid())

        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise

    def _build_classifier(self, input_dim: int, num_labels: int, config):
        """MLP head: [Linear -> LayerNorm -> Tanh -> Dropout]* then a final Linear."""
        dims = [input_dim] + list(config.get('classifier_hidden_dims', [512, 256]))
        layers = []
        for in_dim, out_dim in zip(dims, dims[1:]):
            layers += [
                nn.Linear(in_dim, out_dim),
                nn.LayerNorm(out_dim),
                nn.Tanh(),
                nn.Dropout(config.get('dropout_rate', 0.3)),
            ]
        layers.append(nn.Linear(dims[-1], num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        """Return {'logits', 'severity_pred', 'fluency_pred'} for the batch."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self._attention_pooling(self.positional_encoder(encoded), attention_mask)

        have_linguistic = (word_pos_ids is not None and word_grammar_ids is not None
                           and word_durations is not None)
        if have_linguistic:
            if prosody_features is None:
                # Missing prosody: substitute zeros of the configured width.
                b, s = input_ids.size()
                prosody_features = torch.zeros(
                    b, s, self.config.get('prosody_dim', 32), device=input_ids.device)
            linguistic = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations,
                prosody_features, attention_mask)
        else:
            # No linguistic inputs at all: zero vector of the fused width.
            ling_dim = (self.config.get('pos_emb_dim', 64) + self.config.get('grammar_hidden_dim', 64) +
                        self.config.get('duration_hidden_dim', 128) + self.config.get('prosody_dim', 32)) // 2
            linguistic = torch.zeros(input_ids.size(0), ling_dim, device=input_ids.device)

        fused = self.feature_fusion(torch.cat([pooled, linguistic], dim=1))

        return {
            "logits": self.classifier(fused),
            "severity_pred": self.severity_head(fused),
            "fluency_pred": self.fluency_head(fused),
        }

    def _attention_pooling(self, sequence_output, attention_mask):
        """Content-weighted mean over time; padding positions get zero weight."""
        raw = torch.softmax(torch.sum(sequence_output, dim=-1, keepdim=True), dim=1)
        weights = raw * attention_mask.unsqueeze(-1).float()
        weights = weights / (torch.sum(weights, dim=1, keepdim=True) + 1e-9)
        return torch.sum(sequence_output * weights, dim=1)
259
+
260
# Load configuration and model
def load_model():
    """Build the tokenizer and model from ``config.json`` plus local weights.

    Returns:
        (model, tokenizer, id2label) — model is in eval mode; id2label maps
        stringified class ids to label names.

    Raises:
        Propagates any failure in config/tokenizer/model construction.
        A failure to read ``pytorch_model.bin`` is logged and the model
        keeps its random initialization (best-effort, as before).
    """
    try:
        with open("config.json", "r") as f:
            config = json.load(f)

        logger.info(f"Loaded config: {config}")

        model_name = config.get('model_name', 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Dialogue markers used in the training transcripts.
        special_tokens = ["[DIALOGUE]", "[TURN]", "[PAUSE]", "[REPEAT]", "[HESITATION]"]
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

        num_labels = config.get('num_labels', 9)
        model_config = config.get('model_config', {})

        model = StableAphasiaClassifier(model_config, num_labels)
        # Grow the embedding matrix to cover the added special tokens.
        model.bert.resize_token_embeddings(len(tokenizer))

        try:
            # Security fix: weights_only=True refuses arbitrary pickled
            # objects in the checkpoint (a state_dict only needs tensors).
            state_dict = torch.load("pytorch_model.bin", map_location="cpu",
                                    weights_only=True)
            model.load_state_dict(state_dict)
            logger.info("Successfully loaded model weights")
        except Exception as e:
            logger.error(f"Error loading model weights: {e}")
            logger.info("Using randomly initialized weights")

        model.eval()

        id2label = config.get('id2label', {})
        return model, tokenizer, id2label

    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise
303
+
304
# Initialize the model once at import time. On failure the UI still starts,
# and predict_aphasia reports the problem instead of crashing the app.
try:
    model, tokenizer, id2label = load_model()
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    model, tokenizer, id2label = None, None, {}
311
+
312
def predict_aphasia(text):
    """Classify aphasia type from a transcript string.

    Returns four display strings for the UI textboxes:
    (prediction, confidence, severity, fluency). Error paths now also
    return strings in every slot, consistent with the success path.
    """
    try:
        if model is None or tokenizer is None:
            return "Error: Model not loaded properly. Please check the logs.", "0.0", "N/A", "N/A"

        if not text or not text.strip():
            return "Please enter some text for analysis.", "0.0", "N/A", "N/A"

        inputs = tokenizer(
            text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Raw text carries no POS/grammar/timing information, so feed zero
        # placeholders of the shapes the model expects.
        batch_size, seq_len = inputs["input_ids"].size()
        dummy_pos = torch.zeros(batch_size, seq_len, dtype=torch.long)
        dummy_grammar = torch.zeros(batch_size, seq_len, 3, dtype=torch.long)
        dummy_durations = torch.zeros(batch_size, seq_len, dtype=torch.float)
        dummy_prosody = torch.zeros(batch_size, seq_len, 32, dtype=torch.float)

        with torch.no_grad():
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                word_pos_ids=dummy_pos,
                word_grammar_ids=dummy_grammar,
                word_durations=dummy_durations,
                prosody_features=dummy_prosody
            )

        logits = outputs["logits"]
        probs = F.softmax(logits, dim=1)
        predicted_class_id = torch.argmax(probs, dim=1).item()
        confidence = torch.max(probs, dim=1)[0].item()

        # id2label keys are stringified ints (JSON config).
        predicted_label = id2label.get(str(predicted_class_id), f"Class_{predicted_class_id}")

        severity = torch.argmax(outputs["severity_pred"], dim=1).item()
        fluency = outputs["fluency_pred"].item()

        result = f"Predicted Aphasia Type: {predicted_label}"
        confidence_str = f"Confidence: {confidence:.2%}"
        severity_str = f"Severity Level: {severity}/3"
        fluency_str = f"Fluency Score: {fluency:.3f}"

        # Bug fix: the original returned the raw float `confidence` while
        # the formatted `confidence_str` was computed and then dropped.
        return result, confidence_str, severity_str, fluency_str

    except Exception as e:
        logger.error(f"Prediction error: {e}")
        return f"Error during prediction: {str(e)}", "0.0", "N/A", "N/A"
372
+
373
# Build the Gradio UI.
def create_interface():
    """Assemble the Blocks UI: input column, results column, examples, about."""

    with gr.Blocks(title="Aphasia Classification System") as demo:
        gr.Markdown("# 🧠 Advanced Aphasia Classification System")
        gr.Markdown("Enter speech or text data to classify aphasia type and analyze linguistic patterns.")

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter speech transcription or text for analysis...",
                    lines=5
                )
                submit_btn = gr.Button("Analyze Text", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

            with gr.Column():
                prediction_output = gr.Textbox(label="Prediction Result", lines=2)
                confidence_output = gr.Textbox(label="Confidence Score", lines=1)
                severity_output = gr.Textbox(label="Severity Analysis", lines=1)
                fluency_output = gr.Textbox(label="Fluency Analysis", lines=1)

        result_fields = [prediction_output, confidence_output, severity_output, fluency_output]

        # Run the classifier on click.
        submit_btn.click(
            fn=predict_aphasia,
            inputs=[text_input],
            outputs=result_fields
        )

        # Reset the input box plus all four result boxes.
        clear_btn.click(
            fn=lambda: ("", "", "", "", ""),
            inputs=[],
            outputs=[text_input] + result_fields
        )

        gr.Examples(
            examples=[
                ["The patient... uh... wants to... go home but... cannot... find the words"],
                ["Woman is... is washing dishes and the... the... sink is overflowing with water everywhere"],
                ["Cookie is in the cookie jar on the... on the... what do you call it... the shelf thing"]
            ],
            inputs=text_input
        )

        gr.Markdown("### About")
        gr.Markdown("This system uses a specialized transformer model trained on clinical speech data to classify different types of aphasia.")

    return demo
425
+
426
# Script entry point: build the UI and serve it on all interfaces.
if __name__ == "__main__":
    try:
        demo = create_interface()
        demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
    except Exception as e:
        logger.error(f"Failed to launch app: {e}")
        print(f"Application startup failed: {e}")