Ellie5757575757 committed on
Commit
bf5780d
·
verified ·
1 Parent(s): d6f51e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -19
app.py CHANGED
@@ -10,7 +10,7 @@ from pathlib import Path
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
- # Configuration
14
  MODEL_DIR = "."
15
  SUPPORTED_AUDIO_FORMATS = [".mp3", ".mp4", ".wav", ".m4a", ".flac", ".ogg"]
16
 
@@ -55,9 +55,31 @@ def safe_import_modules():
55
  # Import modules
56
  MODULES = safe_import_modules()
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def run_complete_pipeline(audio_file_path: str) -> dict:
59
  """Complete pipeline: Audio β†’ WAV β†’ CHA β†’ JSON β†’ Model Prediction"""
60
 
 
61
  if not all(MODULES.values()):
62
  missing = [k for k, v in MODULES.items() if v is None]
63
  return {
@@ -118,13 +140,15 @@ def process_audio_input(audio_file):
118
  if audio_file is None:
119
  return "❌ Error: No audio file uploaded"
120
 
 
121
  if not all(MODULES.values()):
122
  return "❌ Error: Audio processing pipeline not available. Missing required modules."
123
 
124
- # Get file path
125
- file_path = audio_file.name if hasattr(audio_file, 'name') else str(audio_file)
126
-
127
  # Check file format
 
 
 
 
128
  file_ext = Path(file_path).suffix.lower()
129
  if file_ext not in SUPPORTED_AUDIO_FORMATS:
130
  return f"❌ Error: Unsupported file format {file_ext}. Supported: {', '.join(SUPPORTED_AUDIO_FORMATS)}"
@@ -160,7 +184,8 @@ def process_audio_input(audio_file):
160
  prob_dist = first_pred["probability_distribution"]
161
  top_3 = list(prob_dist.items())[:3]
162
 
163
- result_text = f"""🧠 **APHASIA CLASSIFICATION RESULTS**
 
164
 
165
  🎯 **Primary Classification:** {predicted_class}
166
  πŸ“Š **Confidence:** {confidence}
@@ -182,9 +207,11 @@ def process_audio_input(audio_file):
182
  πŸ“Š **Processing Summary:**
183
  β€’ Total sentences analyzed: {results.get('total_sentences', 'N/A')}
184
  β€’ Average confidence: {results.get('summary', {}).get('average_confidence', 'N/A')}
 
185
  """
186
 
187
  return result_text
 
188
  else:
189
  return "❌ No predictions generated. The audio file may not contain analyzable speech."
190
 
@@ -193,12 +220,124 @@ def process_audio_input(audio_file):
193
  logger.error(traceback.format_exc())
194
  return f"❌ Processing Error: {str(e)}\n\nPlease check the logs for more details."
195
 
196
- # Create simple interface
197
- def create_simple_interface():
198
- """Create simplified Gradio interface"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- # Create interface with basic components only
201
- iface = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  fn=process_audio_input,
203
  inputs=gr.File(
204
  label="Upload Audio File (MP3, MP4, WAV, M4A, FLAC, OGG)",
@@ -211,23 +350,48 @@ def create_simple_interface():
211
  ),
212
  title="🧠 Aphasia Classification System",
213
  description="Upload audio files to analyze speech patterns and classify aphasia types",
214
- article="<p><strong>About:</strong> This system uses advanced NLP and acoustic analysis to classify different types of aphasia.</p><p><em>For research and clinical assessment purposes.</em></p>"
 
 
 
 
 
 
 
 
 
215
  )
216
 
217
- return iface
218
 
219
  if __name__ == "__main__":
220
  try:
221
  logger.info("Starting Aphasia Classification System...")
222
 
 
 
 
 
223
  # Create and launch interface
224
- demo = create_simple_interface()
225
- demo.launch(
226
- server_name="0.0.0.0",
227
- server_port=7860,
228
- share=False, # Set to False to avoid the share warning
229
- show_error=True
230
- )
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  except Exception as e:
233
  logger.error(f"Failed to launch app: {e}")
 
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
+ # Configuration - Use current directory for model files
14
  MODEL_DIR = "."
15
  SUPPORTED_AUDIO_FORMATS = [".mp3", ".mp4", ".wav", ".m4a", ".flac", ".ogg"]
16
 
 
55
  # Import modules
56
  MODULES = safe_import_modules()
57
 
58
+ def check_model_files():
59
+ """Check if required model files exist"""
60
+ required_files = [
61
+ "pytorch_model.bin",
62
+ "config.json",
63
+ "tokenizer.json",
64
+ "tokenizer_config.json"
65
+ ]
66
+
67
+ missing_files = []
68
+ for file in required_files:
69
+ if not os.path.exists(os.path.join(MODEL_DIR, file)):
70
+ missing_files.append(file)
71
+
72
+ if missing_files:
73
+ logger.error(f"Missing model files: {missing_files}")
74
+ return False, missing_files
75
+
76
+ logger.info("βœ“ All required model files found")
77
+ return True, []
78
+
79
  def run_complete_pipeline(audio_file_path: str) -> dict:
80
  """Complete pipeline: Audio β†’ WAV β†’ CHA β†’ JSON β†’ Model Prediction"""
81
 
82
+ # Check if all modules are available
83
  if not all(MODULES.values()):
84
  missing = [k for k, v in MODULES.items() if v is None]
85
  return {
 
140
  if audio_file is None:
141
  return "❌ Error: No audio file uploaded"
142
 
143
+ # Check if pipeline is available
144
  if not all(MODULES.values()):
145
  return "❌ Error: Audio processing pipeline not available. Missing required modules."
146
 
 
 
 
147
  # Check file format
148
+ file_path = audio_file
149
+ if hasattr(audio_file, 'name'):
150
+ file_path = audio_file.name
151
+
152
  file_ext = Path(file_path).suffix.lower()
153
  if file_ext not in SUPPORTED_AUDIO_FORMATS:
154
  return f"❌ Error: Unsupported file format {file_ext}. Supported: {', '.join(SUPPORTED_AUDIO_FORMATS)}"
 
184
  prob_dist = first_pred["probability_distribution"]
185
  top_3 = list(prob_dist.items())[:3]
186
 
187
+ result_text = f"""
188
+ 🧠 **APHASIA CLASSIFICATION RESULTS**
189
 
190
  🎯 **Primary Classification:** {predicted_class}
191
  πŸ“Š **Confidence:** {confidence}
 
207
  πŸ“Š **Processing Summary:**
208
  β€’ Total sentences analyzed: {results.get('total_sentences', 'N/A')}
209
  β€’ Average confidence: {results.get('summary', {}).get('average_confidence', 'N/A')}
210
+ β€’ Average fluency: {results.get('summary', {}).get('average_fluency_score', 'N/A')}
211
  """
212
 
213
  return result_text
214
+
215
  else:
216
  return "❌ No predictions generated. The audio file may not contain analyzable speech."
217
 
 
220
  logger.error(traceback.format_exc())
221
  return f"❌ Processing Error: {str(e)}\n\nPlease check the logs for more details."
222
 
223
+ def process_text_input(text_input):
224
+ """Process text input directly (fallback option)"""
225
+ try:
226
+ if not text_input or not text_input.strip():
227
+ return "❌ Error: Please enter some text for analysis"
228
+
229
+ # Check if prediction module is available
230
+ if MODULES['predict_from_chajson'] is None:
231
+ return "❌ Error: Text analysis not available. Missing prediction module."
232
+
233
+ # Create a simple JSON structure for text-only input
234
+ temp_json = {
235
+ "sentences": [{
236
+ "sentence_id": "S1",
237
+ "aphasia_type": "UNKNOWN",
238
+ "dialogues": [{
239
+ "INV": [],
240
+ "PAR": [{
241
+ "tokens": text_input.split(),
242
+ "word_pos_ids": [0] * len(text_input.split()),
243
+ "word_grammar_ids": [[0, 0, 0]] * len(text_input.split()),
244
+ "word_durations": [0.0] * len(text_input.split()),
245
+ "utterance_text": text_input
246
+ }]
247
+ }]
248
+ }],
249
+ "text_all": text_input
250
+ }
251
+
252
+ # Save to temporary file
253
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
254
+ json.dump(temp_json, f, ensure_ascii=False, indent=2)
255
+ temp_json_path = f.name
256
+
257
+ # Run prediction
258
+ results = MODULES['predict_from_chajson'](MODEL_DIR, temp_json_path, output_file=None)
259
+
260
+ # Cleanup
261
+ try:
262
+ os.unlink(temp_json_path)
263
+ except:
264
+ pass
265
+
266
+ # Format results
267
+ if "predictions" in results and len(results["predictions"]) > 0:
268
+ first_pred = results["predictions"][0]
269
+
270
+ predicted_class = first_pred["prediction"]["predicted_class"]
271
+ confidence = first_pred["prediction"]["confidence_percentage"]
272
+ description = first_pred["class_description"]["description"]
273
+ severity = first_pred["additional_predictions"]["predicted_severity_level"]
274
+ fluency = first_pred["additional_predictions"]["fluency_rating"]
275
+
276
+ return f"""
277
+ 🧠 **TEXT ANALYSIS RESULTS**
278
+
279
+ 🎯 **Predicted:** {predicted_class}
280
+ πŸ“Š **Confidence:** {confidence}
281
+ πŸ“ˆ **Severity:** {severity}/3
282
+ πŸ—£οΈ **Fluency:** {fluency}
283
+
284
+ πŸ“ **Description:**
285
+ {description}
286
+
287
+ ℹ️ **Note:** Text-based analysis provides limited accuracy compared to audio analysis.
288
+ """
289
+ else:
290
+ return "❌ No predictions generated from text input"
291
+
292
+ except Exception as e:
293
+ logger.error(f"Text processing error: {str(e)}")
294
+ return f"❌ Error: {str(e)}"
295
+
296
+ def detect_environment():
297
+ """Detect if we're running in a cloud environment"""
298
+ # Check for common cloud environment indicators
299
+ cloud_indicators = [
300
+ 'SPACE_ID', # Hugging Face Spaces
301
+ 'PAPERSPACE_NOTEBOOK_REPO_ID', # Paperspace
302
+ 'COLAB_GPU', # Google Colab
303
+ 'KAGGLE_KERNEL_RUN_TYPE', # Kaggle
304
+ 'AWS_LAMBDA_FUNCTION_NAME', # AWS Lambda
305
+ ]
306
+
307
+ is_cloud = any(os.getenv(indicator) for indicator in cloud_indicators)
308
+
309
+ # Also check if we can access localhost
310
+ import socket
311
+ localhost_accessible = False
312
+ try:
313
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
314
+ sock.settimeout(1)
315
+ result = sock.connect_ex(('127.0.0.1', 7860))
316
+ localhost_accessible = (result == 0)
317
+ sock.close()
318
+ except:
319
+ localhost_accessible = False
320
 
321
+ return is_cloud, localhost_accessible
322
+
323
+ def create_interface():
324
+ """Create Gradio interface with proper configuration"""
325
+
326
+ # Check system status
327
+ model_available, missing_files = check_model_files()
328
+ pipeline_available = all(MODULES.values())
329
+
330
+ status_message = "🟒 **System Status: Ready**" if model_available and pipeline_available else "πŸ”΄ **System Status: Issues Detected**"
331
+
332
+ if not model_available:
333
+ status_message += f"\n❌ Missing model files: {', '.join(missing_files)}"
334
+
335
+ if not pipeline_available:
336
+ missing_modules = [k for k, v in MODULES.items() if v is None]
337
+ status_message += f"\n❌ Missing modules: {', '.join(missing_modules)}"
338
+
339
+ # Create interface using simple Interface instead of Blocks to avoid JSON schema issues
340
+ audio_interface = gr.Interface(
341
  fn=process_audio_input,
342
  inputs=gr.File(
343
  label="Upload Audio File (MP3, MP4, WAV, M4A, FLAC, OGG)",
 
350
  ),
351
  title="🧠 Aphasia Classification System",
352
  description="Upload audio files to analyze speech patterns and classify aphasia types",
353
+ article=f"""
354
+ <div style="margin-top: 20px;">
355
+ <h3>System Status</h3>
356
+ <p>{status_message}</p>
357
+ <h3>About</h3>
358
+ <p><strong>Pipeline:</strong> Audio β†’ WAV β†’ CHA β†’ JSON β†’ Classification</p>
359
+ <p><strong>Supported formats:</strong> MP3, MP4, WAV, M4A, FLAC, OGG</p>
360
+ <p><em>For research and clinical assessment purposes.</em></p>
361
+ </div>
362
+ """
363
  )
364
 
365
+ return audio_interface
366
 
367
  if __name__ == "__main__":
368
  try:
369
  logger.info("Starting Aphasia Classification System...")
370
 
371
+ # Detect environment
372
+ is_cloud, localhost_accessible = detect_environment()
373
+ logger.info(f"Environment - Cloud: {is_cloud}, Localhost accessible: {localhost_accessible}")
374
+
375
  # Create and launch interface
376
+ demo = create_interface()
377
+
378
+ # Configure launch parameters based on environment
379
+ launch_kwargs = {
380
+ "server_name": "0.0.0.0",
381
+ "server_port": 7860,
382
+ "show_error": True,
383
+ "quiet": False,
384
+ }
385
+
386
+ # Set share parameter based on environment
387
+ if is_cloud or not localhost_accessible:
388
+ launch_kwargs["share"] = True
389
+ logger.info("Running in cloud environment or localhost not accessible - enabling share")
390
+ else:
391
+ launch_kwargs["share"] = False
392
+ logger.info("Running locally - share disabled")
393
+
394
+ demo.launch(**launch_kwargs)
395
 
396
  except Exception as e:
397
  logger.error(f"Failed to launch app: {e}")