Ali2206 committed
Commit 8955687 · verified · 1 Parent(s): 8482afb

Update app.py

Files changed (1)
  1. app.py +24 -215
app.py CHANGED
@@ -60,57 +60,17 @@ def remove_duplicate_paragraphs(text: str) -> str:
         seen.add(clean_p)
     return "\n\n".join(unique_paragraphs)

-def extract_text_from_excel(path: str) -> str:
-    all_text = []
-    xls = pd.ExcelFile(path)
-    for sheet_name in xls.sheet_names:
-        try:
-            df = xls.parse(sheet_name).astype(str).fillna("")
-        except Exception:
-            continue
-        for _, row in df.iterrows():
-            non_empty = [cell.strip() for cell in row if cell.strip()]
-            if len(non_empty) >= 2:
-                text_line = " | ".join(non_empty)
-                if len(text_line) > 15:
-                    all_text.append(f"[{sheet_name}] {text_line}")
-    return "\n".join(all_text)
-
-def extract_text_from_csv(path: str) -> str:
-    all_text = []
-    try:
-        df = pd.read_csv(path).astype(str).fillna("")
-    except Exception:
-        return ""
-    for _, row in df.iterrows():
-        non_empty = [cell.strip() for cell in row if cell.strip()]
-        if len(non_empty) >= 2:
-            text_line = " | ".join(non_empty)
-            if len(text_line) > 15:
-                all_text.append(text_line)
-    return "\n".join(all_text)
-
-def extract_text_from_pdf(path: str) -> str:
-    import logging
-    logging.getLogger("pdfminer").setLevel(logging.ERROR)
-    all_text = []
-    try:
-        with pdfplumber.open(path) as pdf:
-            for page in pdf.pages:
-                text = page.extract_text()
-                if text:
-                    all_text.append(text.strip())
-    except Exception:
-        return ""
-    return "\n".join(all_text)
-
 def extract_text(file_path: str) -> str:
     if file_path.endswith(".xlsx"):
-        return extract_text_from_excel(file_path)
+        return pd.read_excel(file_path).astype(str).fillna("").to_string(index=False)
     elif file_path.endswith(".csv"):
-        return extract_text_from_csv(file_path)
+        return pd.read_csv(file_path).astype(str).fillna("").to_string(index=False)
     elif file_path.endswith(".pdf"):
-        return extract_text_from_pdf(file_path)
+        try:
+            with pdfplumber.open(file_path) as pdf:
+                return "\n".join(page.extract_text() or '' for page in pdf.pages)
+        except Exception:
+            return ""
     else:
         return ""
@@ -136,6 +96,9 @@ def batch_chunks(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[List[s
 def build_prompt(chunk: str) -> str:
     return f"""### Unstructured Clinical Records\n\nAnalyze the clinical notes below and summarize with:\n- Diagnostic Patterns\n- Medication Issues\n- Missed Opportunities\n- Inconsistencies\n- Follow-up Recommendations\n\n---\n\n{chunk}\n\n---\nRespond concisely in bullet points with clinical reasoning."""

+def remove_non_ascii(text):
+    return ''.join(c for c in text if ord(c) < 256)
+
 def init_agent() -> TxAgent:
     tool_path = os.path.join(tool_cache_dir, "new_tool.json")
     if not os.path.exists(tool_path):
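(Not part of the commit: a quick illustration of the re-added helper above. Despite its name, remove_non_ascii keeps every character whose code point is below 256, i.e. Latin-1, so accented text survives while emoji and most symbols are dropped.)

    def remove_non_ascii(text):
        return ''.join(c for c in text if ord(c) < 256)

    # Accented "é" (U+00E9) is kept; the emoji (U+2705) is removed.
    print(remove_non_ascii("Café ✅ follow-up"))  # -> "Café  follow-up"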
@@ -224,189 +187,35 @@ Avoid repeating the same points multiple times.
     final_response = remove_duplicate_paragraphs(final_response)
     return final_response

-def remove_non_ascii(text):
-    return ''.join(c for c in text if ord(c) < 256)
-
-def generate_pdf_report_with_charts(summary: str, report_path: str, detailed_batches: List[str] = None):
-    chart_dir = os.path.join(os.path.dirname(report_path), "charts")
-    os.makedirs(chart_dir, exist_ok=True)
-
-    # Prepare data
-    categories = ['Diagnostics', 'Medications', 'Missed', 'Inconsistencies', 'Follow-up']
-    values = [4, 2, 3, 1, 5]
-
-    # Chart 1: Bar
-    bar_chart_path = os.path.join(chart_dir, "bar_chart.png")
-    plt.figure(figsize=(6, 4))
-    plt.bar(categories, values)
-    plt.title('Clinical Issues Overview')
-    plt.tight_layout()
-    plt.savefig(bar_chart_path)
-    plt.close()
-
-    # Chart 2: Pie
-    pie_chart_path = os.path.join(chart_dir, "pie_chart.png")
-    plt.figure(figsize=(6, 6))
-    plt.pie(values, labels=categories, autopct='%1.1f%%')
-    plt.title('Issue Distribution')
-    plt.tight_layout()
-    plt.savefig(pie_chart_path)
-    plt.close()
-
-    # Chart 3: Line
-    trend_chart_path = os.path.join(chart_dir, "trend_chart.png")
-    plt.figure(figsize=(6, 4))
-    plt.plot(categories, values, marker='o')
-    plt.title('Trend Analysis')
-    plt.tight_layout()
-    plt.savefig(trend_chart_path)
-    plt.close()
-
-    # PDF init
-    pdf_path = report_path.replace('.md', '.pdf')
-    pdf = FPDF()
-    pdf.set_auto_page_break(auto=True, margin=15)
-
-    # === Title Page ===
-    pdf.add_page()
-    pdf.set_font("Arial", 'B', 24)
-    pdf.cell(0, 20, remove_non_ascii("Final Medical Report"), ln=True, align='C')
-    pdf.set_font("Arial", '', 14)
-    pdf.cell(0, 10, datetime.now().strftime("Generated on %B %d, %Y at %H:%M"), ln=True, align='C')
-    pdf.ln(20)
-    pdf.set_font("Arial", 'I', 12)
-    pdf.multi_cell(0, 10, remove_non_ascii(
-        "This report contains a professional summary of clinical observations, potential inconsistencies, and follow-up recommendations based on the uploaded medical document."
-    ), align="C")
-
-    # === Summary Section ===
-    pdf.add_page()
-    pdf.set_font("Arial", 'B', 16)
-    pdf.cell(0, 10, remove_non_ascii("Final Summary"), ln=True)
-    pdf.set_draw_color(200, 200, 200)
-    pdf.line(10, pdf.get_y(), 200, pdf.get_y())
-    pdf.ln(5)
-    pdf.set_font("Arial", '', 12)
-    for line in summary.split("\n"):
-        clean_line = remove_non_ascii(line.strip())
-        if clean_line:
-            pdf.multi_cell(0, 8, txt=clean_line)
-
-    # === Charts Section ===
-    pdf.add_page()
-    pdf.set_font("Arial", 'B', 16)
-    pdf.cell(0, 10, remove_non_ascii("Statistical Overview"), ln=True)
-    pdf.line(10, pdf.get_y(), 200, pdf.get_y())
-    pdf.ln(5)
-
-    pdf.set_font("Arial", 'B', 12)
-    pdf.cell(0, 10, remove_non_ascii("1. Clinical Issues Overview"), ln=True)
-    pdf.image(bar_chart_path, w=180)
-    pdf.ln(5)
-
-    pdf.cell(0, 10, remove_non_ascii("2. Issue Distribution"), ln=True)
-    pdf.image(pie_chart_path, w=150)
-    pdf.ln(5)
-
-    pdf.cell(0, 10, remove_non_ascii("3. Trend Analysis"), ln=True)
-    pdf.image(trend_chart_path, w=180)
-
-    # === Detailed Tool Outputs ===
-    if detailed_batches:
-        pdf.add_page()
-        pdf.set_font("Arial", 'B', 16)
-        pdf.cell(0, 10, remove_non_ascii("Detailed Tool Insights"), ln=True)
-        pdf.line(10, pdf.get_y(), 200, pdf.get_y())
-        pdf.ln(5)
-
-        for idx, detail in enumerate(detailed_batches):
-            pdf.set_font("Arial", 'B', 13)
-            pdf.cell(0, 10, remove_non_ascii(f"Tool Output #{idx + 1}"), ln=True)
-            pdf.set_font("Arial", '', 11)
-            for line in remove_non_ascii(detail).split("\n"):
-                pdf.multi_cell(0, 8, txt=line.strip())
-            pdf.ln(3)
-
-    pdf.output(pdf_path)
-    return pdf_path
-
-def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
+def handle_analysis(file):
+    messages = []
     if not file or not hasattr(file, "name"):
-        messages.append({"role": "assistant", "content": "❌ Please upload a valid file."})
-        return messages, None
+        return "❌ Please upload a valid file.", None
-
-    start_time = time.time()
-    messages.append({"role": "user", "content": f"📂 Processing file: {os.path.basename(file.name)}"})
-
     try:
         extracted = extract_text(file.name)
         if not extracted:
-            messages.append({"role": "assistant", "content": "❌ Could not extract text."})
-            return messages, None
+            return "❌ Could not extract text.", None

         chunks = split_text(extracted)
         batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
-        messages.append({"role": "assistant", "content": f"🔍 Split into {len(batches)} batches. Analyzing..."})
-
         batch_results = analyze_batches(agent, batches)
-        all_tool_outputs = batch_results.copy()
         valid = [res for res in batch_results if not res.startswith("❌")]

         if not valid:
-            messages.append({"role": "assistant", "content": "❌ No valid batch outputs."})
-            return messages, None
+            return "❌ No valid batch outputs.", None

         summary = generate_final_summary(agent, "\n\n".join(valid))
-
-        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
+        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt")
         with open(report_path, 'w', encoding='utf-8') as f:
-            f.write(f"# Final Medical Report\n\n{summary}")
+            f.write(summary)
+        return summary, report_path
-
-        pdf_path = generate_pdf_report_with_charts(summary, report_path, detailed_batches=all_tool_outputs)
-
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-
-        messages.append({"role": "assistant", "content": f"📊 **Final Report:**\n\n{summary}"})
-        messages.append({"role": "assistant", "content": f"✅ Report generated in **{elapsed_time:.2f} seconds**.\n\n📥 PDF report ready: {os.path.basename(pdf_path)}"})
-
-        return messages, pdf_path
-
     except Exception as e:
-        messages.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
-        return messages, None
+        return f"❌ Error: {str(e)}", None
-
-def create_ui(agent):
-    with gr.Blocks(css="""
-        html, body, .gradio-container { background: #0e1621; color: #e0e0e0; padding: 16px; }
-        button.svelte-1ipelgc { background: linear-gradient(to right, #1e88e5, #0d47a1) !important; border: 1px solid #0d47a1 !important; color: white !important; font-weight: bold !important; padding: 10px 20px !important; border-radius: 8px !important; }
-        button.svelte-1ipelgc:hover { background: linear-gradient(to right, #2196f3, #1565c0) !important; border: 1px solid #1565c0 !important; }
-        .gr-column { align-items: center !important; gap: 12px; }
-        .gr-file, .gr-button { width: 100% !important; max-width: 400px; }
-    """) as demo:
-        gr.Markdown("""
-            <h2 style='text-align:center;'>📄 CPS: Clinical Patient Support System</h2>
-            <p style='text-align:center;'>Analyze and summarize unstructured medical files using AI (optimized for A100 GPU).</p>
-        """)
-
-        with gr.Column():
-            chatbot = gr.Chatbot(label="🧠 CPS Assistant", height=480, type="messages")
-            upload = gr.File(label="📂 Upload Medical File", file_types=[".xlsx", ".csv", ".pdf"])
-            analyze = gr.Button("🧠 Analyze")
-            download = gr.File(label="📥 Download Report", visible=False, interactive=False)
-
-        state = gr.State(value=[])
-
-        def handle_analysis(file, chat):
-            messages, report_path = process_report(agent, file, chat)
-            return messages, gr.update(visible=bool(report_path), value=report_path), messages
-
-        analyze.click(fn=handle_analysis, inputs=[upload, state], outputs=[chatbot, download, state])
-
-    return demo

 if __name__ == "__main__":
     agent = init_agent()
-    ui = create_ui(agent)
-    ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
-
+    gr.Interface(
+        fn=handle_analysis,
+        inputs=gr.File(file_types=[".pdf", ".csv", ".xlsx"]),
+        outputs=[gr.Textbox(label="Summary"), gr.File(label="Download Report")]
+    ).queue().launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
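(Not part of the commit: a minimal sketch of how the new flat return contract could be smoke-tested locally, assuming app.py's module-level imports resolve in the environment. The hypothetical check below only exercises the validation branch, so no agent or model is loaded.)

    # handle_analysis(None) should hit the early-return branch and yield
    # the error string plus no report path.
    from app import handle_analysis

    summary, report_path = handle_analysis(None)
    assert summary.startswith("❌") and report_path is None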
 
 
 