Ellie5757575757 committed on
Commit
5908912
·
verified ·
1 Parent(s): 256d133

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -328
app.py CHANGED
@@ -1,333 +1,17 @@
1
  import gradio as gr
2
- import torch
3
- import json
4
- import re
5
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
- import pandas as pd
7
- from datetime import datetime
8
- import os
9
 
10
- class AphasiaClassifier:
11
- def __init__(self, model_path="./pytorch_model.bin", tokenizer_name="dmis-lab/biobert-base-cased-v1.1"):
12
- """
13
- Initialize the Aphasia Classifier
14
-
15
- Args:
16
- model_path: Path to the fine-tuned pytorch_model.bin
17
- tokenizer_name: Name of the tokenizer to use (BioBERT)
18
- """
19
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
20
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
-
22
- # Load the model - you'll need to adjust this based on your model architecture
23
- try:
24
- # Assuming you have a config.json file with your model configuration
25
- self.model = AutoModelForSequenceClassification.from_pretrained(
26
- "./",
27
- local_files_only=True
28
- )
29
- self.model.to(self.device)
30
- self.model.eval()
31
- except:
32
- # Fallback: create a placeholder model structure
33
- print("Warning: Could not load model. Using placeholder structure.")
34
- self.model = None
35
-
36
- # Define aphasia severity labels (adjust based on your model's classes)
37
- self.severity_labels = {
38
- 0: "Normal",
39
- 1: "Mild Aphasia",
40
- 2: "Moderate Aphasia",
41
- 3: "Severe Aphasia"
42
- }
43
-
44
- def preprocess_to_cha(self, text_input):
45
- """
46
- Convert text input to CHA format
47
-
48
- Args:
49
- text_input: Raw text input from user
50
-
51
- Returns:
52
- cha_formatted: Text formatted in CHA format
53
- """
54
- # Basic CHA formatting - adjust based on your specific CHA requirements
55
- lines = text_input.strip().split('\n')
56
- cha_formatted = []
57
-
58
- for i, line in enumerate(lines):
59
- if line.strip():
60
- # Format as CHA with participant markers
61
- cha_line = f"*PAR:\t{line.strip()}"
62
- cha_formatted.append(cha_line)
63
-
64
- return '\n'.join(cha_formatted)
65
-
66
- def cha_to_json(self, cha_text):
67
- """
68
- Convert CHA format to JSON structure
69
-
70
- Args:
71
- cha_text: Text in CHA format
72
-
73
- Returns:
74
- json_data: Structured JSON data
75
- """
76
- lines = cha_text.split('\n')
77
- utterances = []
78
-
79
- for line in lines:
80
- if line.startswith('*PAR:'):
81
- # Extract the actual speech content
82
- content = line.replace('*PAR:', '').strip()
83
- if content:
84
- utterances.append({
85
- "speaker": "PAR",
86
- "utterance": content,
87
- "timestamp": datetime.now().isoformat()
88
- })
89
-
90
- json_data = {
91
- "session_info": {
92
- "date": datetime.now().strftime("%Y-%m-%d"),
93
- "participant": "PAR"
94
- },
95
- "utterances": utterances
96
- }
97
-
98
- return json_data
99
-
100
- def classify_text(self, json_data):
101
- """
102
- Classify the processed text using the fine-tuned BioBERT model
103
-
104
- Args:
105
- json_data: JSON structured data
106
-
107
- Returns:
108
- classification_results: Classification results in JSON format
109
- """
110
- if self.model is None:
111
- # Return mock results if model couldn't be loaded
112
- return {
113
- "prediction": "Mild Aphasia",
114
- "confidence": 0.85,
115
- "severity_score": 2,
116
- "analysis": {
117
- "total_utterances": len(json_data["utterances"]),
118
- "avg_utterance_length": sum(len(u["utterance"].split()) for u in json_data["utterances"]) / len(json_data["utterances"]) if json_data["utterances"] else 0,
119
- "linguistic_features": {
120
- "word_finding_difficulties": 0.3,
121
- "syntactic_complexity": 0.6,
122
- "semantic_appropriateness": 0.8
123
- }
124
- },
125
- "timestamp": datetime.now().isoformat(),
126
- "model_version": "BioBERT-Aphasia-v1.0"
127
- }
128
-
129
- # Combine all utterances for classification
130
- combined_text = " ".join([utterance["utterance"] for utterance in json_data["utterances"]])
131
-
132
- # Tokenize the input
133
- inputs = self.tokenizer(
134
- combined_text,
135
- return_tensors="pt",
136
- truncation=True,
137
- padding=True,
138
- max_length=512
139
- ).to(self.device)
140
-
141
- # Get prediction
142
- with torch.no_grad():
143
- outputs = self.model(**inputs)
144
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
145
- predicted_class = torch.argmax(predictions, dim=-1).item()
146
- confidence = torch.max(predictions).item()
147
-
148
- # Create detailed results
149
- results = {
150
- "prediction": self.severity_labels[predicted_class],
151
- "confidence": float(confidence),
152
- "severity_score": predicted_class,
153
- "class_probabilities": {
154
- label: float(prob) for label, prob in zip(self.severity_labels.values(), predictions[0].cpu().numpy())
155
- },
156
- "analysis": {
157
- "total_utterances": len(json_data["utterances"]),
158
- "total_words": len(combined_text.split()),
159
- "avg_utterance_length": sum(len(u["utterance"].split()) for u in json_data["utterances"]) / len(json_data["utterances"]) if json_data["utterances"] else 0
160
- },
161
- "timestamp": datetime.now().isoformat(),
162
- "model_version": "BioBERT-Aphasia-v1.0"
163
- }
164
-
165
- return results
166
-
167
- def process_pipeline(self, text_input):
168
- """
169
- Complete processing pipeline: text -> CHA -> JSON -> Classification -> Results
170
-
171
- Args:
172
- text_input: Raw text input
173
-
174
- Returns:
175
- tuple: (cha_formatted, json_data, classification_results, formatted_output)
176
- """
177
- # Step 1: Convert to CHA format
178
- cha_formatted = self.preprocess_to_cha(text_input)
179
-
180
- # Step 2: Convert CHA to JSON
181
- json_data = self.cha_to_json(cha_formatted)
182
-
183
- # Step 3: Classify using model
184
- classification_results = self.classify_text(json_data)
185
-
186
- # Step 4: Format output for display
187
- formatted_output = self.format_results(classification_results)
188
-
189
- return cha_formatted, json.dumps(json_data, indent=2), json.dumps(classification_results, indent=2), formatted_output
190
-
191
- def format_results(self, results):
192
- """
193
- Format results for user-friendly display
194
- """
195
- output = f"""
196
- # Aphasia Classification Results
197
 
198
- ## 🔍 **Prediction**: {results['prediction']}
199
- ## 📊 **Confidence**: {results['confidence']:.2%}
200
- ## 📈 **Severity Score**: {results['severity_score']}/3
 
 
 
 
201
 
202
- ### Detailed Analysis:
203
- - **Total Utterances**: {results['analysis']['total_utterances']}
204
- - **Total Words**: {results['analysis'].get('total_words', 'N/A')}
205
- - **Average Utterance Length**: {results['analysis']['avg_utterance_length']:.1f} words
206
-
207
- ### Class Probabilities:
208
- """
209
-
210
- if 'class_probabilities' in results:
211
- for class_name, prob in results['class_probabilities'].items():
212
- bar = "█" * int(prob * 20) # Simple progress bar
213
- output += f"- **{class_name}**: {prob:.2%} {bar}\n"
214
-
215
- output += f"\n*Analysis completed at: {results['timestamp']}*\n"
216
- output += f"*Model: {results['model_version']}*"
217
-
218
- return output
219
-
220
- # Initialize the classifier
221
- classifier = AphasiaClassifier()
222
-
223
- # Create Gradio interface
224
- def process_text(input_text):
225
- """
226
- Process text through the complete pipeline
227
- """
228
- if not input_text.strip():
229
- return "Please enter some text to analyze.", "", "", ""
230
-
231
- try:
232
- cha_formatted, json_data, classification_json, formatted_results = classifier.process_pipeline(input_text)
233
- return cha_formatted, json_data, classification_json, formatted_results
234
- except Exception as e:
235
- error_msg = f"Error processing text: {str(e)}"
236
- return error_msg, "", "", error_msg
237
-
238
- # Define the Gradio interface
239
- with gr.Blocks(title="Aphasia Classifier", theme=gr.themes.Soft()) as demo:
240
- gr.Markdown("""
241
- # 🧠 Aphasia Classification System
242
-
243
- This application uses a fine-tuned BioBERT model to classify speech patterns and identify potential aphasia severity levels.
244
-
245
- **Pipeline**: Text Input → CHA Format → JSON Structure → BioBERT Classification → Results
246
- """)
247
-
248
- with gr.Row():
249
- with gr.Column(scale=1):
250
- input_text = gr.Textbox(
251
- label="📝 Speech Input",
252
- placeholder="Enter the patient's speech sample here...\nExample: 'The boy is... uh... the boy is climbing the tree. No, wait. The tree... the boy goes up.'",
253
- lines=8,
254
- max_lines=20
255
- )
256
-
257
- classify_btn = gr.Button("🔍 Analyze Speech", variant="primary", size="lg")
258
-
259
- gr.Markdown("""
260
- ### 💡 Tips:
261
- - Enter natural speech samples
262
- - Include hesitations, repetitions, and corrections as they occur
263
- - Multiple sentences provide better analysis
264
- - The model analyzes linguistic patterns and fluency
265
- """)
266
-
267
- with gr.Column(scale=2):
268
- with gr.Tabs():
269
- with gr.TabItem("📊 Results"):
270
- formatted_output = gr.Markdown(
271
- label="Analysis Results",
272
- value="Enter text and click 'Analyze Speech' to see results here."
273
- )
274
-
275
- with gr.TabItem("📄 CHA Format"):
276
- cha_output = gr.Textbox(
277
- label="CHA Formatted Output",
278
- lines=6,
279
- interactive=False
280
- )
281
-
282
- with gr.TabItem("🔧 JSON Data"):
283
- json_output = gr.Textbox(
284
- label="Structured JSON Data",
285
- lines=8,
286
- interactive=False
287
- )
288
-
289
- with gr.TabItem("⚙️ Raw Classification"):
290
- classification_output = gr.Textbox(
291
- label="Raw Classification Results",
292
- lines=10,
293
- interactive=False
294
- )
295
-
296
- # Connect the button to the processing function
297
- classify_btn.click(
298
- fn=process_text,
299
- inputs=[input_text],
300
- outputs=[cha_output, json_output, classification_output, formatted_output]
301
- )
302
-
303
- # Example inputs
304
- gr.Examples(
305
- examples=[
306
- ["The boy is... uh... the boy is climbing the tree. No, wait. The tree... the boy goes up."],
307
- ["I want to... to go to the store. Buy some... what do you call it... bread. Yes, bread and milk."],
308
- ["The cat sat on the mat. It was a sunny day and the birds were singing in the trees."],
309
- ["Doctor, I feel... I feel not good. My head... it hurts here. Since yesterday."]
310
- ],
311
- inputs=[input_text]
312
- )
313
-
314
- gr.Markdown("""
315
- ---
316
- ### ⚠️ **Disclaimer**:
317
- This tool is for research and educational purposes only. It should not be used as a substitute for professional medical diagnosis or treatment. Always consult with qualified healthcare professionals for medical advice.
318
-
319
- ### 🔧 **Technical Details**:
320
- - **Model**: Fine-tuned BioBERT (dmis-lab/biobert-base-cased-v1.1)
321
- - **Input**: Natural language speech samples
322
- - **Output**: Severity classification (Normal, Mild, Moderate, Severe)
323
- - **Features**: CHA formatting, JSON structuring, confidence scores
324
- """)
325
-
326
- # Launch the app
327
  if __name__ == "__main__":
328
- demo.launch(
329
- server_name="0.0.0.0",
330
- server_port=7860,
331
- share=False, # Set to True if you want a public link
332
- debug=True
333
- )
 
1
  import gradio as gr
2
+ from pipeline import run_pipeline
 
 
 
 
 
 
3
 
4
def infer(file):
    """Run the aphasia classification pipeline on an uploaded/recorded audio file.

    Args:
        file: The value produced by ``gr.Audio(type="filepath")``. Gradio
            passes a plain ``str`` filesystem path in this mode (or ``None``
            when no audio was uploaded/recorded) — NOT a tempfile-like
            object, so the previous ``file.name`` access raised
            ``AttributeError`` on every call. A file-like object with a
            ``.name`` attribute is still accepted for robustness.

    Returns:
        str: JSON result string from ``run_pipeline`` (``out_style="json"``),
        or a JSON error object when no audio was provided.
    """
    if file is None:
        # Graceful message instead of an AttributeError traceback in the UI.
        return '{"error": "No audio provided."}'
    # type="filepath" yields a str; fall back to .name for file-like objects.
    path = file if isinstance(file, str) else file.name
    return run_pipeline(path, out_style="json")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
# Single-function Gradio UI: audio in, JSON classification result out.
# ffmpeg on the backend accepts mp3/mp4/wav uploads or microphone capture.
_audio_in = gr.Audio(sources=["upload", "microphone"], type="filepath")
_json_out = gr.Code(label="Result (JSON)")

demo = gr.Interface(
    fn=infer,
    inputs=_audio_in,
    outputs=_json_out,
    title="Aphasia Classification",
    description="Upload audio/video; pipeline: ffmpeg → .cha → JSON → model → result.",
)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
if __name__ == "__main__":
    # Script entry point: launch the Gradio app with default settings
    # (when hosted on HF Spaces, the platform supplies host/port).
    demo.launch()