Spaces:
Build error
Build error
| import os | |
| import json | |
| import requests | |
| import gradio as gr | |
| import pandas as pd | |
| from rapidfuzz import process, fuzz | |
| import datetime | |
| import openai # New import for OpenAI API | |
# Read the OpenAI API key from the OPENAI_API_KEY environment variable.
# os.getenv returns None when the variable is unset; nothing here validates it,
# so a missing key presumably only surfaces once the first OpenAI request is made.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Load drug interaction data (one row per drug pair: "Drug 1", "Drug 2",
# "Interaction Description").
df_ddi = pd.read_csv("db_drug_interactions.csv")
# Every drug name appearing in either column, used as fuzzy-match candidates.
# Kept as a *sorted* list rather than a bare set: set iteration order for
# strings changes between processes (hash randomization), which would make
# fuzzy-match score ties resolve to different drugs from one run to the next.
unique_drug_names = sorted(
    set(df_ddi["Drug 1"].dropna().unique()).union(df_ddi["Drug 2"].dropna().unique()),
    key=str,
)
# File cache: LLM responses keyed by the exact prompt text, persisted as JSON.
cache_file = "cache.json"


def load_cache(path):
    """Load the JSON response cache stored at *path*.

    Returns an empty dict when the file is missing, unreadable, not valid JSON,
    or does not hold a JSON object, so a damaged cache never stops the app from
    starting.
    """
    if not os.path.exists(path):
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as err:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        print(f"Ignoring unreadable cache file {path!r}: {err}")
        return {}
    return data if isinstance(data, dict) else {}


cache = load_cache(cache_file)


def save_cache():
    """Persist the module-level ``cache`` dict to ``cache_file`` as JSON.

    The data is written to a temporary file first and then moved over the real
    file with ``os.replace``, so an interrupted write cannot leave a truncated
    cache that would break the next start-up.
    """
    tmp_path = f"{cache_file}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(cache, f)
    os.replace(tmp_path, cache_file)
# Prompt sent to the LLM, filled in with str.format using two placeholders:
#   {verified_data}  - bullet summary of dataset-matched drug interactions
#   {patient_chart}  - the plain-text chart shown in the UI
# Any literal braces added to this text must be doubled ({{ }}) to survive str.format.
# NOTE(review): the "SYSTEM MESSAGE" and "USER MESSAGE" sections are sent together
# as a single user-role message, not as separate system/user chat messages.
prompt_template = """
SYSTEM MESSAGE:
You are a knowledgeable clinical decision-support assistant. You have access to the patient's medical chart, which contains details on diagnoses, medications, allergies, and other relevant clinical data. Additionally, you have verified drug interaction data from a curated dataset. Your role is to identify any notable drug interactions or side effects that may pose a risk. When providing explanations, include specific facts from the patient's chart and verified interaction data to substantiate your reasoning. If no significant interactions are found, state that none were identified. Provide this output in clear, clinical language, without making an actual medical diagnosis or prescribing decisions.
USER MESSAGE:
Below is the verified drug interaction data and the patient chart. Please review them carefully and identify:
1. Any major or moderate drug interactions that may pose a risk.
2. Important side effects or contraindications related to the patient's medications, especially given their comorbidities and allergies.
3. A concise explanation for each potential interaction or side effect, referencing the relevant portions of the patient chart and verified data.
Verified Drug Interaction Data:
{verified_data}
Patient Chart:
{patient_chart}
INSTRUCTIONS:
1. Read the entire patient chart and verified data carefully.
2. Identify the patient's active medications, relevant conditions, allergies, and pertinent labs/diagnoses.
3. Cross-reference the verified interaction data above and incorporate it into your explanation.
4. List the drug interactions and a short one line explanation. No need to give lot of details."""
# Helper Functions
# FHIR Helper
# Root of the HAPI FHIR R4 server (public test instance) queried for patient data.
# HTTPS keeps patient identifiers and chart contents encrypted in transit.
BASE_URL = "https://hapi.fhir.org/baseR4"
def calculate_age(birth_date_str, today=None):
    """Return the age in whole years for a ``YYYY-MM-DD`` birth date, as a string.

    Args:
        birth_date_str: Birth date formatted as ``%Y-%m-%d`` (e.g. a FHIR
            ``birthDate``).
        today: Reference date (``date`` or ``datetime``); defaults to today's
            local date. Mainly useful for deterministic results.

    Returns:
        The age as a decimal string, or "" when the value is missing or is not
        a full ``YYYY-MM-DD`` date (partial dates such as ``"1990"`` included).
    """
    try:
        birth_date = datetime.datetime.strptime(birth_date_str, "%Y-%m-%d").date()
    except (TypeError, ValueError):  # None / non-string, or wrong format
        return ""
    if today is None:
        today = datetime.date.today()
    # Subtract one more year when this year's birthday has not happened yet.
    age = today.year - birth_date.year - (
        (today.month, today.day) < (birth_date.month, birth_date.day)
    )
    return str(age)
def deduplicate_list(items):
    """Return a new list of *items* with repeats removed.

    The first occurrence of each value is kept and the original order is
    preserved. Items must be hashable.
    """
    # dict keys are unique and remember insertion order, which gives an
    # order-stable de-duplication in a single pass.
    return list(dict.fromkeys(items))
_FHIR_TIMEOUT_SECONDS = 15  # per-request limit so a stalled FHIR server cannot hang the UI


def _codeable_text(concept):
    """Return readable text for a FHIR CodeableConcept-style dict.

    Uses ``text`` when present, otherwise the first ``coding`` entry's
    ``display``; returns "" when neither exists, including when ``coding`` is
    an empty list or *concept* is missing.
    """
    if not concept:
        return ""
    codings = concept.get("coding") or [{}]
    return concept.get("text") or codings[0].get("display") or ""


def _fhir_search_entries(resource_type, patient_id):
    """Return the ``entry`` list of a ``<resource_type>?patient=<id>`` FHIR search.

    Returns None when the server answers with a non-200 status and [] when the
    search succeeds but has no entries. Only the first page of the search
    Bundle is read; next-page links are not followed.
    """
    response = requests.get(
        f"{BASE_URL}/{resource_type}",
        params={"patient": patient_id},  # requests URL-encodes the value
        timeout=_FHIR_TIMEOUT_SECONDS,
    )
    if response.status_code != 200:
        return None
    return response.json().get("entry", [])


def get_patient_chart(patient_id):
    """Build a plain-text chart for *patient_id* from the FHIR server.

    Reads the Patient resource plus Condition, MedicationRequest,
    AllergyIntolerance and Observation searches for that patient and formats
    them as one ``Label: value`` line each.

    Returns:
        ``(chart, medications_list)``: the chart text and the de-duplicated list
        of medication strings (name plus any dosage text), or ``(None, [])``
        when the ID is blank or the Patient resource cannot be retrieved.

    Raises:
        requests.RequestException: on network failures or timeouts.
    """
    from urllib.parse import quote

    patient_id = "" if patient_id is None else str(patient_id).strip()
    if not patient_id:
        # A blank ID would request ".../Patient/", which is a search, not a read.
        return None, []

    # Percent-encode the user-supplied ID so it cannot alter the URL path/query.
    response = requests.get(
        f"{BASE_URL}/Patient/{quote(patient_id, safe='')}",
        timeout=_FHIR_TIMEOUT_SECONDS,
    )
    if response.status_code != 200:
        return None, []
    data = response.json()
    if data.get("resourceType") != "Patient":
        return None, []

    patient_name = age = additional_notes = ""
    names = data.get("name") or []
    if names:
        name_entry = names[0]
        given = " ".join(name_entry.get("given", []))
        family = name_entry.get("family", "")
        patient_name = f"{given} {family}".strip()
    if "birthDate" in data:
        age = calculate_age(data["birthDate"])
    if "note" in data:
        additional_notes = " ".join(n.get("text", "") for n in data["note"]).strip()

    # Diagnoses
    condition_list = []
    for entry in _fhir_search_entries("Condition", patient_id) or []:
        text = _codeable_text(entry.get("resource", {}).get("code"))
        if text:
            condition_list.append(text)
    diagnoses = ", ".join(deduplicate_list(condition_list))

    # Medications: concept text followed by any free-text dosage instructions
    med_list = []
    for entry in _fhir_search_entries("MedicationRequest", patient_id) or []:
        med = entry.get("resource", {})
        text = _codeable_text(med.get("medicationCodeableConcept"))
        for dosage in med.get("dosageInstruction", []):
            if "text" in dosage:
                text += " " + dosage["text"]
        if text.strip():
            med_list.append(text.strip())
    medications_list = deduplicate_list(med_list)
    medications = ", ".join(medications_list)

    # Allergies
    allergy_entries = _fhir_search_entries("AllergyIntolerance", patient_id)
    allergy_list = []
    for entry in allergy_entries or []:
        text = _codeable_text(entry.get("resource", {}).get("code"))
        if text:
            allergy_list.append(text)
    allergies = ", ".join(deduplicate_list(allergy_list))
    if not allergies:
        # A failed query must not be reported as "no allergies": that would give
        # the downstream LLM false reassurance.
        allergies = (
            "No known drug allergies"
            if allergy_entries is not None
            else "Unavailable (allergy query failed)"
        )

    # Lab results: only quantity and string values are rendered
    obs_list = []
    for entry in _fhir_search_entries("Observation", patient_id) or []:
        obs = entry.get("resource", {})
        code_text = _codeable_text(obs.get("code"))
        value = ""
        if "valueQuantity" in obs:
            q = obs["valueQuantity"]
            value = f"{q.get('value', '')} {q.get('unit', '')}".strip()
        elif "valueString" in obs:
            value = obs["valueString"]
        if code_text and value:
            obs_list.append(f"{code_text}: {value}")
    lab_results = ", ".join(deduplicate_list(obs_list))

    patient_chart = (f"Patient Name: {patient_name}\n"
                     f"Age: {age}\n"
                     f"Diagnoses: {diagnoses}\n"
                     f"Medications: {medications}\n"
                     f"Allergies: {allergies}\n"
                     f"Lab Results: {lab_results}\n"
                     f"Additional Notes: {additional_notes}")
    return patient_chart, medications_list
def get_best_drug_matches(query, drug_list, limit=1):
    """Fuzzy-match *query* against the candidates in *drug_list*.

    Scores use rapidfuzz's ``fuzz.partial_ratio`` (best-matching substring),
    so a chart entry with dosage text can still match a bare drug name.
    Returns rapidfuzz ``process.extract`` output: up to *limit* tuples whose
    first two items are the matched candidate and its score.
    """
    candidates = process.extract(
        query,
        drug_list,
        limit=limit,
        scorer=fuzz.partial_ratio,
    )
    return candidates
def extract_medications(chart):
    """Return the first word of each entry on a chart's ``Medications:`` line.

    The chart is the text produced by ``get_patient_chart``; entries on the
    ``Medications:`` line are comma-separated. Blank entries (e.g. a patient
    with no medications, giving ``"Medications: "``) are skipped instead of
    raising IndexError. A missing chart or missing line yields ``[]``.
    """
    if not chart:
        return []
    prefix = "Medications:"
    line = next((row for row in chart.split("\n") if row.startswith(prefix)), "")
    meds = []
    for entry in line[len(prefix):].split(","):
        words = entry.split()
        if words:
            meds.append(words[0])
    return meds
def find_ddi_for_meds(matched, df):
    """Return interaction rows from *df* for every pair of matched drug names.

    Args:
        matched: Mapping of chart medication string -> matched drug name; only
            the values are used.
        df: DataFrame with ``"Drug 1"`` and ``"Drug 2"`` columns.

    Returns:
        List of row dicts (``DataFrame.to_dict("records")`` form) whose drug
        pair matches, in either column order and case-insensitively. Each drug
        name is considered once, so two medications that map to the same drug
        do not produce duplicate interaction rows.
    """
    # Case-insensitive de-duplication, keeping first-occurrence order.
    names = list(dict.fromkeys(str(name).lower() for name in matched.values()))
    if len(names) < 2:
        return []
    # Lower-case the columns once instead of once per pair.
    drug1 = df["Drug 1"].str.lower()
    drug2 = df["Drug 2"].str.lower()
    rows = []
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            mask = ((drug1 == a) & (drug2 == b)) | ((drug1 == b) & (drug2 == a))
            rows.extend(df[mask].to_dict("records"))
    return rows
def summarize_ddi_rows(rows):
    """Render interaction rows as a bulleted list for the LLM prompt.

    Each row becomes ``- <Drug 1> & <Drug 2>: <Interaction Description>``.
    When there are no rows, a fixed "none found" sentence is returned instead.
    """
    bullets = [
        f"- {row['Drug 1']} & {row['Drug 2']}: {row['Interaction Description']}"
        for row in rows
    ]
    if not bullets:
        return "No known major interactions found."
    return "\n".join(bullets)
# Stage 1: Fetch chart
def fetch_chart(patient_id):
    """Gradio handler for the "Fetch Patient Chart" button.

    Returns a 3-tuple wired to ``[chart_display, chart_display, chart_state]``:
    the chart text (None when nothing could be loaded), a ``gr.update`` that
    keeps the chart box visible, and the medication list for the analysis step.

    Failures (unknown ID, network/timeout, malformed server response) are
    reported with a non-blocking ``gr.Warning`` toast instead of an unhandled
    exception. They also clear the chart and medication state, so a previous
    patient's data is not left on screen to be analyzed by mistake.
    """
    try:
        chart, meds_list = get_patient_chart(patient_id)
    except (requests.RequestException, ValueError) as err:
        gr.Warning(f"Could not load patient data from the FHIR server: {err}")
        chart, meds_list = None, []
    else:
        if chart is None:
            gr.Warning(f"No patient chart found for ID {patient_id!r}.")
    return chart, gr.update(visible=True), meds_list
# Stage 2: Analyze interaction using OpenAI API
def _match_drug_name(med, threshold=80):
    """Map one chart medication string to a dataset drug name.

    Returns the best fuzzy match from ``unique_drug_names`` when its score is
    at least *threshold*; otherwise, including when there are no candidates,
    returns *med* unchanged.
    """
    matches = get_best_drug_matches(med, unique_drug_names)
    if matches and matches[0][1] >= threshold:
        return matches[0][0]
    return med


def _stream_completion(prompt, model="gpt-4o-mini"):
    """Yield the text fragments of a streaming chat completion for *prompt*.

    Supports the openai>=1.0 client API and falls back to the legacy 0.x
    ``openai.ChatCompletion`` API, which no longer exists in 1.0+.
    """
    messages = [{"role": "user", "content": prompt}]
    if hasattr(openai, "OpenAI"):  # openai>=1.0
        client = openai.OpenAI(api_key=openai.api_key)
        stream = client.chat.completions.create(model=model, messages=messages, stream=True)
        for chunk in stream:
            # Some chunks carry no choices, and delta.content is None on
            # role-only/final chunks.
            if chunk.choices:
                yield chunk.choices[0].delta.content or ""
    else:  # legacy openai<1.0
        stream = openai.ChatCompletion.create(model=model, messages=messages, stream=True)
        for chunk in stream:
            delta = chunk["choices"][0].get("delta", {})
            yield delta.get("content", "") or ""


def analyze_interaction(patient_chart, meds_list):
    """Gradio generator that streams an LLM drug-interaction review.

    Matches each medication to a dataset drug name, looks up known pairwise
    interactions, and fills ``prompt_template`` with them and the chart. If
    that exact prompt is already in ``cache``, yields the cached text once.
    Otherwise it streams from the model, yielding debug info plus the response
    accumulated so far after every chunk. The final text is stored in the cache
    and saved.

    Args:
        patient_chart: Chart text from the chart textbox; empty/None means no
            chart was fetched.
        meds_list: Medication strings held in ``chart_state`` (None before the
            first fetch).
    """
    if not patient_chart:
        yield "Please fetch a patient chart first."
        return
    meds_list = meds_list or []

    # One fuzzy search per medication.
    matched = {m: _match_drug_name(m) for m in meds_list}
    ddi_rows = find_ddi_for_meds(matched, df_ddi)
    ddi_summary = summarize_ddi_rows(ddi_rows)
    final_prompt = prompt_template.format(verified_data=ddi_summary, patient_chart=patient_chart)
    print(f"Final Prompt :\n{final_prompt}")

    if final_prompt in cache:
        print("Fetching from cache..")
        yield cache[final_prompt]
        return

    debug = (
        "\n----- Extracted Meds -----\n"
        f"{meds_list}\n"
        "----- Matches -----\n"
        f"{matched}\n"
        "----- DDI Summary -----\n"
        f"{ddi_summary}\n"
    )
    response = ""
    for piece in _stream_completion(final_prompt):
        response += piece
        # Yield the debug header plus everything received so far.
        yield debug + response
    cache[final_prompt] = debug + response
    save_cache()
# Gradio App (Default Layout)
# Two-step UI: (1) fetch a patient chart by ID, (2) stream an LLM interaction review of it.
with gr.Blocks(title="Drug Interaction Assistant", theme=gr.themes.Soft()) as app:
    gr.Markdown("## 🧪 Drug Interaction Checker")
    gr.Markdown("### Step 1: Extract Patient Chart")
    patient_input = gr.Textbox(label="Patient ID", placeholder="e.g. 46705085")
    fetch_btn = gr.Button("Fetch Patient Chart")
    chart_display = gr.Textbox(label="Patient Chart", lines=15, interactive=False, visible=True)
    # Holds the medication list returned by fetch_chart for use by the analysis step.
    chart_state = gr.State()
    gr.Markdown("### Step 2: Analyze Drug Interactions")
    analyze_btn = gr.Button("Run Interaction Analysis")
    analyze_output = gr.Textbox(label="LLM Risk Assessment", lines=20, interactive=False)
    # fetch_chart returns (chart text, gr.update(visible=True), meds list), so chart_display
    # appears twice: once for its value and once for the visibility update.
    # NOTE(review): listing one component twice in outputs is unusual; returning a single
    # gr.update(value=..., visible=True) would need fetch_chart changed in the same commit.
    fetch_btn.click(fn=fetch_chart, inputs=patient_input, outputs=[chart_display, chart_display, chart_state])
    # analyze_interaction is a generator, so each yield streams into analyze_output.
    analyze_btn.click(fn=analyze_interaction, inputs=[chart_display, chart_state], outputs=analyze_output)
if __name__ == "__main__":
    app.launch()