File size: 9,394 Bytes
01ab923
 
 
 
 
 
69c4dc7
01ab923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcd3c7b
01ab923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90e3caa
 
a7ae4c1
01ab923
69c4dc7
bcd3c7b
69c4dc7
01ab923
69c4dc7
 
 
 
 
 
 
 
 
bcd3c7b
 
 
69c4dc7
 
 
 
 
 
 
bcd3c7b
 
 
 
69c4dc7
 
 
 
 
 
4ae788a
69c4dc7
 
01ab923
 
4ae788a
01ab923
 
90e3caa
 
 
 
01ab923
 
 
 
 
 
 
 
90e3caa
 
 
 
 
4ae788a
 
 
 
90e3caa
 
01ab923
 
 
 
 
 
 
b4e60d5
01ab923
 
 
 
 
 
 
b4e60d5
01ab923
69c4dc7
a7ae4c1
01ab923
 
 
 
 
 
90e3caa
01ab923
90e3caa
01ab923
 
 
 
 
 
 
90e3caa
01ab923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20b7ba6
01ab923
 
 
 
 
 
a7ae4c1
01ab923
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import gradio as gr
import langextract as lx
import textwrap
import os
import uuid  # 用于生成唯一的文件名
import fitz  # 导入 PyMuPDF 库
import time  # 导入 time 模块用于等待

# --- 1. LangExtract extraction setup (behavior unchanged) ---
def get_extraction_config():
    """Build the prompt and few-shot examples that drive LangExtract.

    Returns:
        tuple[str, list]: (prompt_description, examples) suitable for
        passing straight to ``lx.extract``.
    """
    prompt_description = textwrap.dedent("""
        从文本中提取药物及其相关的详细信息,并使用属性对相关信息进行分组:
        1.  按照实体在文本中出现的顺序进行提取。
        2.  每个被提取的实体都必须有一个 'medication_group' 属性,用来将其与对应的药物关联起来。
        3.  关于同一种药物的所有详细信息(如剂量、频率)都应共享相同的 'medication_group' 值。
    """)
    # (class, extracted span, medication group) rows for the single worked example.
    example_rows = [
        ("药物", "阿司匹林", "阿司匹林"),
        ("剂量", "100mg", "阿司匹林"),
        ("频率", "每日", "阿司匹林"),
        ("病症", "心脏健康", "阿司匹林"),
        ("药物", "辛伐他汀", "辛伐他汀"),
        ("剂量", "20mg", "辛伐他汀"),
        ("频率", "睡前", "辛伐他汀"),
    ]
    examples = [
        lx.data.ExampleData(
            text="病人每日服用阿司匹林 100mg 以保护心脏健康,并在睡前服用辛伐他汀 20mg。",
            extractions=[
                lx.data.Extraction(
                    extraction_class=cls,
                    extraction_text=span,
                    attributes={"medication_group": group},
                )
                for cls, span, group in example_rows
            ],
        )
    ]
    return prompt_description, examples

# --- 2. Gradio core handler, decomposed into single-purpose helpers ---
def _load_source_text(input_text, input_file):
    """Return the text to analyze: an uploaded PDF wins over pasted text.

    Returns "" when neither input is usable. Raises gr.Error on PDF failure.
    """
    if input_file is not None:
        text = ""
        try:
            with fitz.open(input_file.name) as doc:
                for page in doc:
                    text += page.get_text()
        except Exception as e:
            raise gr.Error(f"PDF 文件解析失败: {e}")
        return text
    if input_text and input_text.strip():
        return input_text
    return ""


def _extract_with_retries(source_text, prompt, examples, api_key,
                          max_retries=3, retry_delay=5):
    """Call lx.extract, retrying transient API errors; raise gr.Error otherwise.

    Retries only on errors whose message contains "503" (server overloaded)
    or "429" (rate limited); any other failure, or the last failed attempt,
    surfaces as gr.Error.
    """
    for attempt in range(max_retries):
        try:
            print(f"Attempting to call Gemini API, attempt #{attempt + 1}")
            result = lx.extract(
                text_or_documents=source_text,
                prompt_description=prompt,
                examples=examples,
                model_id="gemini-2.5-flash",
                api_key=api_key,
                # A single worker forces sequential chunk processing so the
                # free-tier rate limit is not exceeded.
                max_workers=1,
                extraction_passes=2,
                max_char_buffer=1500
            )
            print("API call successful.")
            return result
        except Exception as e:
            print(f"Attempt #{attempt + 1} failed with error: {e}")
            message = str(e)
            if ("503" in message or "429" in message) and attempt < max_retries - 1:
                error_type = "503" if "503" in message else "429"
                print(f"API Error ({error_type}). Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                raise gr.Error(f"信息提取过程中发生错误: {e}")
    # Defensive: only reachable if the loop exits without return or raise.
    raise gr.Error("所有重试均失败,无法从API获取结果。请稍后再试。")


def _build_highlights(result, source_text):
    """Build the (text, label) pairs consumed by gr.HighlightedText.

    Only extractions with character grounding are shown; overlapping or
    out-of-bounds spans are skipped to keep the segments non-overlapping.
    """
    grounded = [e for e in result.extractions if e.char_interval]
    grounded.sort(key=lambda e: e.char_interval.start_pos)
    highlighted_text = []
    last_pos = 0
    for entity in grounded:
        start, end = entity.char_interval.start_pos, entity.char_interval.end_pos
        if start >= last_pos and end <= len(source_text):
            highlighted_text.append((source_text[last_pos:start], None))
            highlighted_text.append((entity.extraction_text, entity.extraction_class))
            last_pos = end
    highlighted_text.append((source_text[last_pos:], None))
    return highlighted_text


def _build_structured_markdown(result):
    """Render extractions grouped by 'medication_group' as a Markdown report."""
    medication_groups = {}
    for extraction in result.extractions:
        # attributes may be None on an extraction; treat that as ungrouped
        # instead of crashing on None.get (bug fix).
        attributes = extraction.attributes or {}
        group_name = attributes.get("medication_group", "未分组")
        medication_groups.setdefault(group_name, []).append(extraction)

    structured_output = "### 结构化提取结果\n\n"
    if not medication_groups:
        structured_output += "未提取到任何药物信息。"
        return structured_output
    for med_name, extractions in medication_groups.items():
        structured_output += f"#### 药物组: {med_name}\n"
        # Ungrounded extractions (no char_interval) sort first via -1.
        for extraction in sorted(extractions, key=lambda e: e.char_interval.start_pos if e.char_interval else -1):
            pos_info = ""
            if extraction.char_interval:
                pos_info = f" (位置: {extraction.char_interval.start_pos}-{extraction.char_interval.end_pos})"
            structured_output += f"- **{extraction.extraction_class}**: {extraction.extraction_text}{pos_info}\n"
        structured_output += "\n"
    return structured_output


def _save_outputs(result):
    """Persist the result as JSONL plus an interactive HTML visualization.

    Returns (html_path, jsonl_path) under /tmp, keyed by a fresh UUID so
    concurrent sessions never clobber each other's files.
    """
    session_id = str(uuid.uuid4())
    output_dir = "/tmp"
    os.makedirs(output_dir, exist_ok=True)
    jsonl_filename = f"extraction_{session_id}.jsonl"
    jsonl_path = os.path.join(output_dir, jsonl_filename)
    html_path = os.path.join(output_dir, f"visualization_{session_id}.html")
    lx.io.save_annotated_documents([result], output_name=jsonl_filename, output_dir=output_dir)
    html_content = lx.visualize(jsonl_path)
    with open(html_path, "w", encoding="utf-8") as f:
        f.write(html_content)
    return html_path, jsonl_path


def process_input_and_visualize(input_text, input_file):
    """Gradio handler: extract medication info from text or PDF and visualize it.

    Args:
        input_text: Raw clinical text pasted by the user (may be empty).
        input_file: Uploaded PDF file object from gr.File (or None).

    Returns:
        (highlighted_text, structured_markdown, html_path, html_path, jsonl_path)
        matching the five wired output components; placeholder values when
        no input was provided.

    Raises:
        gr.Error: on PDF parse failure, missing API key, or API failure.
    """
    source_text = _load_source_text(input_text, input_file)
    if not source_text:
        return None, "请输入文本或上传一个PDF文件...", None, None, None

    prompt, examples = get_extraction_config()
    api_key = os.environ.get("LANGEXTRACT_API_KEY")
    if not api_key:
        raise gr.Error("错误:服务器未配置 LANGEXTRACT_API_KEY。")

    result = _extract_with_retries(source_text, prompt, examples, api_key)

    highlighted_text = _build_highlights(result, source_text)
    structured_output = _build_structured_markdown(result)
    html_path, jsonl_path = _save_outputs(result)

    return highlighted_text, structured_output, html_path, html_path, jsonl_path


# --- 3. Build the Gradio application UI (layout unchanged) ---
# Left column: text/PDF inputs + examples; right column: NER highlights,
# structured relations, interactive HTML view, and downloadable artifacts.
with gr.Blocks(theme=gr.themes.Soft(), title="药物信息提取器") as demo:
    gr.Markdown("# ⚕️ LangExtract 药物信息提取器")
    gr.Markdown("一个基于大型语言模型的智能工具,可从**临床文本**或 **PDF 文件**中自动提取药物、剂量等信息,并进行结构化关联。")
    
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.TabItem("📄 临床文本输入"):
                    input_textbox = gr.Textbox(lines=15, label="粘贴临床笔记", placeholder="请在此处粘贴文本...")
                with gr.TabItem("📁 上传PDF文件"):
                    input_file_uploader = gr.File(label="选择一个 PDF 文件进行分析", file_types=['.pdf'])
            
            submit_btn = gr.Button("🧠 提取信息", variant="primary")
            # Clicking an example fills the textbox above.
            gr.Examples(
                examples=[
                    "The patient was prescribed Lisinopril and Metformin last month.\nHe takes the Lisinopril 10mg daily for hypertension, but often misses his Metformin 500mg dose which should be taken twice daily for diabetes.",
                    "Patient took 400 mg PO Ibuprofen q4h for two days for a headache.",
                ],
                inputs=input_textbox,
                label="示例文本 (点击自动填充到上方文本框)"
            )

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("📊 总览 (NER & RE)"):
                    gr.Markdown("### 命名实体识别 (NER) - 文本高亮")
                    # color_map keys must match the extraction_class labels
                    # produced by process_input_and_visualize.
                    output_highlight = gr.HighlightedText(label="实体高亮显示", color_map={"药物": "#FF6347", "剂量": "#FFA500", "频率": "#32CD32", "病症": "#4169E1"})
                    output_structured = gr.Markdown(label="结构化关系")
                with gr.TabItem("🌐 交互式可视化"):
                    gr.Markdown("### 交互式可视化图表")
                    output_html_viewer = gr.HTML(label="交互式图表 (可缩放和筛选)")
                with gr.TabItem("📁 文件下载"):
                    gr.Markdown("### 下载提取结果")
                    download_html = gr.File(label="下载交互式 HTML 文件")
                    download_jsonl = gr.File(label="下载 JSONL 数据文件")

    # Run extraction, then clear both inputs once the outputs are populated.
    submit_btn.click(
        fn=process_input_and_visualize,
        inputs=[input_textbox, input_file_uploader],
        outputs=[output_highlight, output_structured, output_html_viewer, download_html, download_jsonl]
    ).then(
        lambda: (None, None), 
        inputs=None, 
        outputs=[input_textbox, input_file_uploader]
    )


if __name__ == "__main__":
    # Start the Gradio server when this file is executed directly.
    demo.launch()