# pharmextract / app.py
# Author: leonsimon23 (verified commit bcd3c7b, "Update app.py")
import html
import os
import tempfile
import textwrap
import time
import uuid  # unique per-request output filenames

import fitz  # PyMuPDF
import gradio as gr
import langextract as lx
# --- 1. LangExtract extraction configuration (prompt and few-shot examples) ---
def get_extraction_config():
    """Build the prompt text and few-shot examples passed to ``lx.extract``.

    Returns:
        tuple: ``(prompt_description, examples)`` -- the dedented instruction
        text, and a one-element list holding an ``lx.data.ExampleData`` whose
        extractions each carry a ``medication_group`` attribute.
    """
    instructions = textwrap.dedent(
        """
        从文本中提取药物及其相关的详细信息,并使用属性对相关信息进行分组:
        1. 按照实体在文本中出现的顺序进行提取。
        2. 每个被提取的实体都必须有一个 'medication_group' 属性,用来将其与对应的药物关联起来。
        3. 关于同一种药物的所有详细信息(如剂量、频率)都应共享相同的 'medication_group' 值。
        """
    )

    sample_sentence = "病人每日服用阿司匹林 100mg 以保护心脏健康,并在睡前服用辛伐他汀 20mg。"

    # (extraction_class, extraction_text, medication_group) for each annotated span.
    annotated_spans = [
        ("药物", "阿司匹林", "阿司匹林"),
        ("剂量", "100mg", "阿司匹林"),
        ("频率", "每日", "阿司匹林"),
        ("病症", "心脏健康", "阿司匹林"),
        ("药物", "辛伐他汀", "辛伐他汀"),
        ("剂量", "20mg", "辛伐他汀"),
        ("频率", "睡前", "辛伐他汀"),
    ]

    example = lx.data.ExampleData(
        text=sample_sentence,
        extractions=[
            lx.data.Extraction(
                extraction_class=label,
                extraction_text=span_text,
                attributes={"medication_group": group},
            )
            for label, span_text, group in annotated_spans
        ],
    )
    return instructions, [example]
# --- 2. Core Gradio handler: input -> extraction -> rendered outputs ---
def process_input_and_visualize(input_text, input_file):
    """Run medication extraction on pasted text or an uploaded PDF.

    An uploaded PDF takes precedence over ``input_text``. The source text is
    sent to Gemini via ``lx.extract`` (retrying on 503/429 errors), and the
    result is rendered for the five Gradio outputs.

    Args:
        input_text: Text from the clinical-notes textbox; may be ``None`` or blank.
        input_file: Value of the ``gr.File`` component -- either a path string
            or an object exposing ``.name``; ``None`` when nothing was uploaded.

    Returns:
        ``(highlighted_text, structured_markdown, viewer_html, html_path, jsonl_path)``.
        ``viewer_html`` is an ``<iframe srcdoc=...>`` snippet for ``gr.HTML``.
        With no usable input, returns ``(None, <prompt message>, None, None, None)``.

    Raises:
        gr.Error: if the PDF cannot be parsed or has no text, the API key is
            not configured, or the extraction call fails.
    """
    source_text = _read_source_text(input_text, input_file)
    if not source_text:
        return None, "请输入文本或上传一个PDF文件...", None, None, None

    api_key = os.environ.get("LANGEXTRACT_API_KEY")
    if not api_key:
        raise gr.Error("错误:服务器未配置 LANGEXTRACT_API_KEY。")

    result = _extract_with_retry(source_text, api_key)
    extractions = result.extractions or []

    highlighted_text = _build_highlighted_text(source_text, extractions)
    structured_output = _build_structured_markdown(extractions)
    viewer_html, html_path, jsonl_path = _save_outputs(result)
    return highlighted_text, structured_output, viewer_html, html_path, jsonl_path


def _read_source_text(input_text, input_file):
    """Return the PDF text if a file was uploaded, else the textbox text ("" if neither)."""
    if input_file is not None:
        # Depending on the Gradio version, gr.File yields a path string or a
        # tempfile-like object with a ``.name`` path.
        pdf_path = getattr(input_file, "name", input_file)
        try:
            with fitz.open(pdf_path) as doc:
                pdf_text = "".join(page.get_text() for page in doc)
        except Exception as e:
            raise gr.Error(f"PDF 文件解析失败: {e}") from e
        if not pdf_text.strip():
            # e.g. a scanned PDF with no text layer -- nothing to send to the model.
            raise gr.Error("PDF 文件中未找到可提取的文本(可能是扫描件)。")
        return pdf_text
    if input_text and input_text.strip():
        return input_text
    return ""


def _extract_with_retry(source_text, api_key, max_retries=3, retry_delay=5):
    """Call ``lx.extract`` on *source_text*, retrying transient Gemini errors.

    Errors whose message contains "503" (server overloaded) or "429" (rate
    limited) are retried, for at most ``max_retries`` attempts in total. The
    first retry waits ``retry_delay`` seconds and the wait doubles each time.
    Any other error, or a failure on the last attempt, is raised as ``gr.Error``.
    """
    prompt, examples = get_extraction_config()
    result = None
    for attempt in range(max_retries):
        print(f"Attempting to call Gemini API, attempt #{attempt + 1}")
        try:
            result = lx.extract(
                text_or_documents=source_text,
                prompt_description=prompt,
                examples=examples,
                model_id="gemini-2.5-flash",
                api_key=api_key,
                # A single worker processes chunks sequentially, to stay under
                # the free-tier rate limit.
                max_workers=1,
                extraction_passes=2,
                max_char_buffer=1500,
            )
        except Exception as e:
            print(f"Attempt #{attempt + 1} failed with error: {e}")
            message = str(e)
            if ("503" in message or "429" in message) and attempt < max_retries - 1:
                error_type = "503" if "503" in message else "429"
                delay = retry_delay * 2 ** attempt  # exponential backoff
                print(f"API Error ({error_type}). Retrying in {delay} seconds...")
                time.sleep(delay)
                continue
            raise gr.Error(f"信息提取过程中发生错误: {e}") from e
        print("API call successful.")
        break
    if result is None:
        raise gr.Error("所有重试均失败,无法从API获取结果。请稍后再试。")
    return result


def _span(extraction):
    """Return ``(start, end)`` character offsets of *extraction*, or ``None`` if ungrounded."""
    interval = extraction.char_interval
    if interval is None or interval.start_pos is None or interval.end_pos is None:
        return None
    return interval.start_pos, interval.end_pos


def _build_highlighted_text(source_text, extractions):
    """Split *source_text* into ``(segment, label)`` pairs for ``gr.HighlightedText``.

    Only grounded extractions are highlighted. A span that overlaps an earlier
    highlight, or runs past the end of the text, is skipped. Highlighted
    segments are sliced from *source_text* itself, so the pieces always
    concatenate back to the source even if ``extraction_text`` differs from
    the aligned span.
    """
    spans = [(_span(e), e) for e in extractions]
    grounded = sorted(
        ((span, e) for span, e in spans if span is not None),
        key=lambda item: item[0][0],
    )
    segments = []
    cursor = 0
    for (start, end), entity in grounded:
        if cursor <= start <= end <= len(source_text):
            segments.append((source_text[cursor:start], None))
            segments.append((source_text[start:end], entity.extraction_class))
            cursor = end
    segments.append((source_text[cursor:], None))
    return segments


def _group_label(extraction):
    """Return the extraction's ``medication_group`` as a hashable label ("未分组" if absent)."""
    # attributes may be None when the model returned none for this extraction.
    group = (extraction.attributes or {}).get("medication_group", "未分组")
    if isinstance(group, list):
        # Attribute values may be lists, which cannot be used as dict keys.
        group = ", ".join(map(str, group))
    return group


def _build_structured_markdown(extractions):
    """Render extractions grouped by medication as a Markdown summary."""
    groups = {}
    for extraction in extractions:
        groups.setdefault(_group_label(extraction), []).append(extraction)

    parts = ["### 结构化提取结果\n\n"]
    if not groups:
        parts.append("未提取到任何药物信息。")
    for med_name, members in groups.items():
        parts.append(f"#### 药物组: {med_name}\n")
        # Ungrounded extractions (no offsets) sort first.
        for extraction in sorted(members, key=lambda e: (_span(e) or (-1, -1))[0]):
            span = _span(extraction)
            pos_info = f" (位置: {span[0]}-{span[1]})" if span else ""
            parts.append(
                f"- **{extraction.extraction_class}**: {extraction.extraction_text}{pos_info}\n"
            )
        parts.append("\n")
    return "".join(parts)


def _save_outputs(result):
    """Write *result* as JSONL plus a standalone HTML visualization in the temp dir.

    Returns:
        ``(viewer_html, html_path, jsonl_path)`` where ``viewer_html`` is an
        iframe snippet embedding the visualization, suitable for ``gr.HTML``.
    """
    session_id = str(uuid.uuid4())
    output_dir = tempfile.gettempdir()
    jsonl_filename = f"extraction_{session_id}.jsonl"
    jsonl_path = os.path.join(output_dir, jsonl_filename)
    html_path = os.path.join(output_dir, f"visualization_{session_id}.html")

    lx.io.save_annotated_documents([result], output_name=jsonl_filename, output_dir=output_dir)
    html_content = lx.visualize(jsonl_path)
    # In notebook environments lx.visualize returns an IPython HTML object
    # (content in ``.data``) rather than a plain string.
    html_content = getattr(html_content, "data", html_content)
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(html_content)

    # gr.HTML renders markup, not a file path. Embed the standalone page in an
    # iframe so its own scripts can run.
    # NOTE(review): confirm iframe srcdoc renders in the deployed Gradio version.
    viewer_html = (
        f'<iframe srcdoc="{html.escape(html_content, quote=True)}" '
        'style="width: 100%; height: 600px; border: none;"></iframe>'
    )
    return viewer_html, html_path, jsonl_path
# --- 3. Gradio application UI ---
with gr.Blocks(theme=gr.themes.Soft(), title="药物信息提取器") as demo:
    # Page header.
    gr.Markdown("# ⚕️ LangExtract 药物信息提取器")
    gr.Markdown("一个基于大型语言模型的智能工具,可从**临床文本**或 **PDF 文件**中自动提取药物、剂量等信息,并进行结构化关联。")

    with gr.Row():
        # Left column: inputs (pasted text or a PDF upload) plus sample texts.
        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.TabItem("📄 临床文本输入"):
                    input_textbox = gr.Textbox(lines=15, label="粘贴临床笔记", placeholder="请在此处粘贴文本...")
                with gr.TabItem("📁 上传PDF文件"):
                    input_file_uploader = gr.File(label="选择一个 PDF 文件进行分析", file_types=['.pdf'])
            submit_btn = gr.Button("🧠 提取信息", variant="primary")
            gr.Examples(
                examples=[
                    "The patient was prescribed Lisinopril and Metformin last month.\nHe takes the Lisinopril 10mg daily for hypertension, but often misses his Metformin 500mg dose which should be taken twice daily for diabetes.",
                    "Patient took 400 mg PO Ibuprofen q4h for two days for a headache.",
                ],
                inputs=input_textbox,
                label="示例文本 (点击自动填充到上方文本框)"
            )

        # Right column: the five outputs of process_input_and_visualize.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("📊 总览 (NER & RE)"):
                    gr.Markdown("### 命名实体识别 (NER) - 文本高亮")
                    output_highlight = gr.HighlightedText(label="实体高亮显示", color_map={"药物": "#FF6347", "剂量": "#FFA500", "频率": "#32CD32", "病症": "#4169E1"})
                    output_structured = gr.Markdown(label="结构化关系")
                with gr.TabItem("🌐 交互式可视化"):
                    gr.Markdown("### 交互式可视化图表")
                    output_html_viewer = gr.HTML(label="交互式图表 (可缩放和筛选)")
                with gr.TabItem("📁 文件下载"):
                    gr.Markdown("### 下载提取结果")
                    download_html = gr.File(label="下载交互式 HTML 文件")
                    download_jsonl = gr.File(label="下载 JSONL 数据文件")

    # Clear the inputs only after a successful run (.success rather than .then),
    # so a failed extraction (e.g. a rate-limit error) does not discard the
    # user's pasted text or uploaded file.
    submit_btn.click(
        fn=process_input_and_visualize,
        inputs=[input_textbox, input_file_uploader],
        outputs=[output_highlight, output_structured, output_html_viewer, download_html, download_jsonl]
    ).success(
        lambda: (None, None),
        inputs=None,
        outputs=[input_textbox, input_file_uploader]
    )

if __name__ == "__main__":
    demo.launch()