Ellie5757575757 committed on
Commit
d6f51e7
·
verified ·
1 Parent(s): 0c63b16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -192
app.py CHANGED
@@ -10,8 +10,8 @@ from pathlib import Path
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
- # Configuration - Use current directory for model files
14
- MODEL_DIR = "." # Changed from "./adaptive_aphasia_model"
15
  SUPPORTED_AUDIO_FORMATS = [".mp3", ".mp4", ".wav", ".m4a", ".flac", ".ogg"]
16
 
17
  def safe_import_modules():
@@ -55,31 +55,9 @@ def safe_import_modules():
55
  # Import modules
56
  MODULES = safe_import_modules()
57
 
58
- def check_model_files():
59
- """Check if required model files exist"""
60
- required_files = [
61
- "pytorch_model.bin",
62
- "config.json",
63
- "tokenizer.json",
64
- "tokenizer_config.json"
65
- ]
66
-
67
- missing_files = []
68
- for file in required_files:
69
- if not os.path.exists(os.path.join(MODEL_DIR, file)):
70
- missing_files.append(file)
71
-
72
- if missing_files:
73
- logger.error(f"Missing model files: {missing_files}")
74
- return False, missing_files
75
-
76
- logger.info("✓ All required model files found")
77
- return True, []
78
-
79
  def run_complete_pipeline(audio_file_path: str) -> dict:
80
  """Complete pipeline: Audio → WAV → CHA → JSON → Model Prediction"""
81
 
82
- # Check if all modules are available
83
  if not all(MODULES.values()):
84
  missing = [k for k, v in MODULES.items() if v is None]
85
  return {
@@ -140,15 +118,13 @@ def process_audio_input(audio_file):
140
  if audio_file is None:
141
  return "❌ Error: No audio file uploaded"
142
 
143
- # Check if pipeline is available
144
  if not all(MODULES.values()):
145
  return "❌ Error: Audio processing pipeline not available. Missing required modules."
146
 
147
- # Check file format
148
- file_path = audio_file
149
- if hasattr(audio_file, 'name'):
150
- file_path = audio_file.name
151
 
 
152
  file_ext = Path(file_path).suffix.lower()
153
  if file_ext not in SUPPORTED_AUDIO_FORMATS:
154
  return f"❌ Error: Unsupported file format {file_ext}. Supported: {', '.join(SUPPORTED_AUDIO_FORMATS)}"
@@ -184,8 +160,7 @@ def process_audio_input(audio_file):
184
  prob_dist = first_pred["probability_distribution"]
185
  top_3 = list(prob_dist.items())[:3]
186
 
187
- result_text = f"""
188
- 🧠 **APHASIA CLASSIFICATION RESULTS**
189
 
190
  🎯 **Primary Classification:** {predicted_class}
191
  📊 **Confidence:** {confidence}
@@ -207,11 +182,9 @@ def process_audio_input(audio_file):
207
  📊 **Processing Summary:**
208
  • Total sentences analyzed: {results.get('total_sentences', 'N/A')}
209
  • Average confidence: {results.get('summary', {}).get('average_confidence', 'N/A')}
210
- • Average fluency: {results.get('summary', {}).get('average_fluency_score', 'N/A')}
211
  """
212
 
213
  return result_text
214
-
215
  else:
216
  return "❌ No predictions generated. The audio file may not contain analyzable speech."
217
 
@@ -220,176 +193,39 @@ def process_audio_input(audio_file):
220
  logger.error(traceback.format_exc())
221
  return f"❌ Processing Error: {str(e)}\n\nPlease check the logs for more details."
222
 
223
- def process_text_input(text_input):
224
- """Process text input directly (fallback option)"""
225
- try:
226
- if not text_input or not text_input.strip():
227
- return "❌ Error: Please enter some text for analysis"
228
-
229
- # Check if prediction module is available
230
- if MODULES['predict_from_chajson'] is None:
231
- return "❌ Error: Text analysis not available. Missing prediction module."
232
-
233
- # Create a simple JSON structure for text-only input
234
- temp_json = {
235
- "sentences": [{
236
- "sentence_id": "S1",
237
- "aphasia_type": "UNKNOWN",
238
- "dialogues": [{
239
- "INV": [],
240
- "PAR": [{
241
- "tokens": text_input.split(),
242
- "word_pos_ids": [0] * len(text_input.split()),
243
- "word_grammar_ids": [[0, 0, 0]] * len(text_input.split()),
244
- "word_durations": [0.0] * len(text_input.split()),
245
- "utterance_text": text_input
246
- }]
247
- }]
248
- }],
249
- "text_all": text_input
250
- }
251
-
252
- # Save to temporary file
253
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
254
- json.dump(temp_json, f, ensure_ascii=False, indent=2)
255
- temp_json_path = f.name
256
-
257
- # Run prediction
258
- results = MODULES['predict_from_chajson'](MODEL_DIR, temp_json_path, output_file=None)
259
-
260
- # Cleanup
261
- try:
262
- os.unlink(temp_json_path)
263
- except:
264
- pass
265
-
266
- # Format results
267
- if "predictions" in results and len(results["predictions"]) > 0:
268
- first_pred = results["predictions"][0]
269
-
270
- predicted_class = first_pred["prediction"]["predicted_class"]
271
- confidence = first_pred["prediction"]["confidence_percentage"]
272
- description = first_pred["class_description"]["description"]
273
- severity = first_pred["additional_predictions"]["predicted_severity_level"]
274
- fluency = first_pred["additional_predictions"]["fluency_rating"]
275
-
276
- return f"""
277
- 🧠 **TEXT ANALYSIS RESULTS**
278
-
279
- 🎯 **Predicted:** {predicted_class}
280
- 📊 **Confidence:** {confidence}
281
- 📈 **Severity:** {severity}/3
282
- 🗣️ **Fluency:** {fluency}
283
-
284
- 📝 **Description:**
285
- {description}
286
-
287
- ℹ️ **Note:** Text-based analysis provides limited accuracy compared to audio analysis.
288
- """
289
- else:
290
- return "❌ No predictions generated from text input"
291
-
292
- except Exception as e:
293
- logger.error(f"Text processing error: {str(e)}")
294
- return f"❌ Error: {str(e)}"
295
-
296
- def create_interface():
297
  """Create simplified Gradio interface"""
298
 
299
- # Check system status
300
- model_available, missing_files = check_model_files()
301
- pipeline_available = all(MODULES.values())
302
-
303
- status_message = "🟢 **System Status: Ready**" if model_available and pipeline_available else "🔴 **System Status: Issues Detected**"
304
-
305
- if not model_available:
306
- status_message += f"\n❌ Missing model files: {', '.join(missing_files)}"
307
-
308
- if not pipeline_available:
309
- missing_modules = [k for k, v in MODULES.items() if v is None]
310
- status_message += f"\n❌ Missing modules: {', '.join(missing_modules)}"
311
-
312
- # Create interface
313
- with gr.Blocks(title="Aphasia Classification System") as demo:
314
-
315
- gr.HTML("""
316
- <div style="text-align: center; margin-bottom: 2rem;">
317
- <h1>🧠 Advanced Aphasia Classification System</h1>
318
- <p>Upload audio files or enter text to analyze speech patterns and classify aphasia types</p>
319
- </div>
320
- """)
321
-
322
- gr.Markdown(status_message)
323
-
324
- with gr.Tabs():
325
- # Audio Tab
326
- with gr.Tab("🎵 Audio Analysis"):
327
- gr.Markdown("### Upload Audio File")
328
- gr.Markdown("**Supported formats:** MP3, MP4, WAV, M4A, FLAC, OGG")
329
-
330
- audio_input = gr.File(
331
- label="Upload Audio File",
332
- file_types=["audio"]
333
- )
334
-
335
- audio_btn = gr.Button("🔍 Analyze Audio", variant="primary")
336
-
337
- audio_output = gr.Textbox(
338
- label="Analysis Results",
339
- lines=20,
340
- max_lines=30
341
- )
342
-
343
- audio_btn.click(
344
- fn=process_audio_input,
345
- inputs=audio_input,
346
- outputs=audio_output
347
- )
348
-
349
- # Text Tab
350
- with gr.Tab("📝 Text Analysis"):
351
- gr.Markdown("### Direct Text Input")
352
- gr.Markdown("**Note:** Audio analysis provides more accurate results")
353
-
354
- text_input = gr.Textbox(
355
- label="Enter Text",
356
- placeholder="Enter speech transcription or text for analysis...",
357
- lines=5
358
- )
359
-
360
- text_btn = gr.Button("🔍 Analyze Text", variant="secondary")
361
-
362
- text_output = gr.Textbox(
363
- label="Analysis Results",
364
- lines=15,
365
- max_lines=20
366
- )
367
-
368
- text_btn.click(
369
- fn=process_text_input,
370
- inputs=text_input,
371
- outputs=text_output
372
- )
373
-
374
- gr.HTML("""
375
- <div style="text-align: center; margin-top: 2rem; padding: 1rem; border-top: 1px solid #eee;">
376
- <p><strong>About:</strong> This system uses advanced NLP and acoustic analysis to classify different types of aphasia.</p>
377
- <p><em>For research and clinical assessment purposes.</em></p>
378
- </div>
379
- """)
380
-
381
- return demo
382
 
383
  if __name__ == "__main__":
384
  try:
385
  logger.info("Starting Aphasia Classification System...")
386
 
387
  # Create and launch interface
388
- demo = create_interface()
389
  demo.launch(
390
  server_name="0.0.0.0",
391
  server_port=7860,
392
- share=True, # This fixes the localhost error
393
  show_error=True
394
  )
395
 
 
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
# Configuration
# Model files are expected in the current working directory (changed from
# "./adaptive_aphasia_model" so the Space root holds the weights directly).
MODEL_DIR = "."
# Audio extensions accepted by the upload pipeline; compared against the
# lowercased Path suffix in process_audio_input.
SUPPORTED_AUDIO_FORMATS = [".mp3", ".mp4", ".wav", ".m4a", ".flac", ".ogg"]
16
 
17
  def safe_import_modules():
 
55
  # Import modules
56
  MODULES = safe_import_modules()
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def run_complete_pipeline(audio_file_path: str) -> dict:
59
  """Complete pipeline: Audio → WAV → CHA → JSON → Model Prediction"""
60
 
 
61
  if not all(MODULES.values()):
62
  missing = [k for k, v in MODULES.items() if v is None]
63
  return {
 
118
  if audio_file is None:
119
  return "❌ Error: No audio file uploaded"
120
 
 
121
  if not all(MODULES.values()):
122
  return "❌ Error: Audio processing pipeline not available. Missing required modules."
123
 
124
+ # Get file path
125
+ file_path = audio_file.name if hasattr(audio_file, 'name') else str(audio_file)
 
 
126
 
127
+ # Check file format
128
  file_ext = Path(file_path).suffix.lower()
129
  if file_ext not in SUPPORTED_AUDIO_FORMATS:
130
  return f"❌ Error: Unsupported file format {file_ext}. Supported: {', '.join(SUPPORTED_AUDIO_FORMATS)}"
 
160
  prob_dist = first_pred["probability_distribution"]
161
  top_3 = list(prob_dist.items())[:3]
162
 
163
+ result_text = f"""🧠 **APHASIA CLASSIFICATION RESULTS**
 
164
 
165
  🎯 **Primary Classification:** {predicted_class}
166
  📊 **Confidence:** {confidence}
 
182
  📊 **Processing Summary:**
183
  • Total sentences analyzed: {results.get('total_sentences', 'N/A')}
184
  • Average confidence: {results.get('summary', {}).get('average_confidence', 'N/A')}
 
185
  """
186
 
187
  return result_text
 
188
  else:
189
  return "❌ No predictions generated. The audio file may not contain analyzable speech."
190
 
 
193
  logger.error(traceback.format_exc())
194
  return f"❌ Processing Error: {str(e)}\n\nPlease check the logs for more details."
195
 
196
# Create simple interface
def create_simple_interface():
    """Create simplified Gradio interface"""

    # Build the input/output widgets first, then wire them into a
    # single-function gr.Interface: audio file in, text report out.
    upload_widget = gr.File(
        label="Upload Audio File (MP3, MP4, WAV, M4A, FLAC, OGG)",
        file_types=["audio"],
    )
    report_widget = gr.Textbox(
        label="Analysis Results",
        lines=25,
        max_lines=50,
    )

    return gr.Interface(
        fn=process_audio_input,
        inputs=upload_widget,
        outputs=report_widget,
        title="🧠 Aphasia Classification System",
        description="Upload audio files to analyze speech patterns and classify aphasia types",
        article="<p><strong>About:</strong> This system uses advanced NLP and acoustic analysis to classify different types of aphasia.</p><p><em>For research and clinical assessment purposes.</em></p>",
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
  if __name__ == "__main__":
220
  try:
221
  logger.info("Starting Aphasia Classification System...")
222
 
223
  # Create and launch interface
224
+ demo = create_simple_interface()
225
  demo.launch(
226
  server_name="0.0.0.0",
227
  server_port=7860,
228
+ share=False, # Set to False to avoid the share warning
229
  show_error=True
230
  )
231