Spaces:
Paused
Paused
Ankit Thakur committed on
Commit ·
c57fdf3
1
Parent(s): bf85b2f
everything
Browse files- app.py +52 -99
- config.py +21 -22
- requirements.txt +11 -11
- validate_prescription.py +87 -286
app.py
CHANGED
|
@@ -1,137 +1,90 @@
|
|
| 1 |
import os
|
| 2 |
import streamlit as st
|
| 3 |
-
from
|
| 4 |
-
from config import (
|
| 5 |
-
STATIC_DIR,
|
| 6 |
-
UPLOADS_DIR,
|
| 7 |
-
HF_TOKEN,
|
| 8 |
-
GOOGLE_API_KEY,
|
| 9 |
-
GOOGLE_CSE_ID,
|
| 10 |
-
GEMINI_API_KEY,
|
| 11 |
-
DEVICE
|
| 12 |
-
)
|
| 13 |
|
| 14 |
-
#
|
| 15 |
-
st.set_page_config(
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
'Get Help': 'https://github.com/your-repo',
|
| 21 |
-
'About': "RxGuard v1.0 - Advanced Prescription Validation"
|
| 22 |
-
}
|
| 23 |
-
)
|
| 24 |
|
| 25 |
-
# ─── Session State ────────────────────────────────────────────────────────
|
| 26 |
if "analysis_result" not in st.session_state:
|
| 27 |
st.session_state.analysis_result = None
|
| 28 |
if "uploaded_filename" not in st.session_state:
|
| 29 |
st.session_state.uploaded_filename = None
|
| 30 |
|
| 31 |
-
# ─── UI Components ────────────────────────────────────────────────────────
|
| 32 |
def show_service_status():
|
| 33 |
"""Displays service connectivity status."""
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
st.metric("Gemini", "✅" if GEMINI_API_KEY else "❌")
|
| 41 |
-
with cols[3]:
|
| 42 |
-
st.metric("Device", DEVICE.upper())
|
| 43 |
-
|
| 44 |
-
def display_patient_info(info: dict):
|
| 45 |
-
"""Displays patient information in a formatted card."""
|
| 46 |
-
with st.container(border=True):
|
| 47 |
-
st.subheader("👤 Patient Details")
|
| 48 |
-
cols = st.columns(2)
|
| 49 |
-
with cols[0]:
|
| 50 |
-
st.markdown(f"**Name:** {info.get('Name', 'Not detected')}")
|
| 51 |
-
st.markdown(f"**Age:** {info.get('Age', 'N/A')}")
|
| 52 |
-
with cols[1]:
|
| 53 |
-
st.markdown(f"**Date:** {info.get('Date', 'N/A')}")
|
| 54 |
-
st.markdown(f"**Physician:** {info.get('PhysicianName', 'N/A')}")
|
| 55 |
-
|
| 56 |
-
def display_medications(medications: list):
|
| 57 |
-
"""Displays medication information with verification."""
|
| 58 |
-
st.subheader("💊 Medications")
|
| 59 |
-
if not medications:
|
| 60 |
-
st.warning("No medications detected in prescription")
|
| 61 |
-
return
|
| 62 |
-
|
| 63 |
-
for med in medications:
|
| 64 |
-
with st.expander(f"{med.get('drug_raw', 'Unknown Medication')}"):
|
| 65 |
-
cols = st.columns([1, 2])
|
| 66 |
-
with cols[0]:
|
| 67 |
-
st.markdown(f"""
|
| 68 |
-
**Dosage:** `{med.get('dosage', 'N/A')}`
|
| 69 |
-
**Frequency:** `{med.get('frequency', 'N/A')}`
|
| 70 |
-
""")
|
| 71 |
-
|
| 72 |
-
with cols[1]:
|
| 73 |
-
if verification := med.get("verification"):
|
| 74 |
-
if dosage := verification.get("standard_dosage"):
|
| 75 |
-
st.success(f"**Standard Dosage:** {dosage}")
|
| 76 |
-
if side_effects := verification.get("side_effects"):
|
| 77 |
-
st.warning(f"**Side Effects:** {side_effects}")
|
| 78 |
-
if interactions := verification.get("interactions"):
|
| 79 |
-
st.error(f"**Interactions:** {interactions}")
|
| 80 |
|
| 81 |
-
# ─── Main Application ─────────────────────────────────────────────────────
|
| 82 |
def main():
|
| 83 |
st.title("⚕️ RxGuard Prescription Validator")
|
| 84 |
-
st.caption("
|
| 85 |
-
|
| 86 |
show_service_status()
|
| 87 |
-
|
| 88 |
# Only enable upload if required services are available
|
| 89 |
-
if all([HF_TOKEN, GOOGLE_API_KEY
|
| 90 |
uploaded_file = st.file_uploader(
|
| 91 |
-
"Upload prescription image (PNG/JPG/JPEG):",
|
| 92 |
type=["png", "jpg", "jpeg"],
|
| 93 |
-
help="
|
| 94 |
)
|
| 95 |
-
|
| 96 |
if uploaded_file and uploaded_file.name != st.session_state.uploaded_filename:
|
| 97 |
with st.status("Analyzing prescription...", expanded=True) as status:
|
| 98 |
try:
|
| 99 |
-
# Store the uploaded file
|
| 100 |
st.session_state.uploaded_filename = uploaded_file.name
|
| 101 |
file_path = os.path.join(UPLOADS_DIR, uploaded_file.name)
|
| 102 |
-
|
| 103 |
with open(file_path, "wb") as f:
|
| 104 |
f.write(uploaded_file.getvalue())
|
| 105 |
-
|
| 106 |
-
#
|
| 107 |
from validate_prescription import extract_prescription_info
|
| 108 |
st.session_state.analysis_result = extract_prescription_info(file_path)
|
| 109 |
-
|
| 110 |
status.update(label="Analysis complete!", state="complete", expanded=False)
|
| 111 |
except Exception as e:
|
| 112 |
-
st.error(f"
|
| 113 |
st.session_state.analysis_result = {"error": str(e)}
|
| 114 |
status.update(label="Analysis failed", state="error")
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
| 122 |
else:
|
| 123 |
-
|
| 124 |
-
|
|
|
|
| 125 |
with tab1:
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
with tab2:
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
if st.toggle("Show technical details"):
|
| 134 |
-
st.json(result.get("debug_info", {}))
|
| 135 |
|
| 136 |
if __name__ == "__main__":
|
| 137 |
main()
|
|
|
|
| 1 |
import os
|
| 2 |
import streamlit as st
|
| 3 |
+
from config import STATIC_DIR, HF_TOKEN, GOOGLE_API_KEY, DEVICE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
# Page-level Streamlit settings (title, icon, wide layout).
st.set_page_config(
    page_title="RxGuard Prescription Validator",
    page_icon="⚕️",
    layout="wide",
)

# Uploaded images are written to an "uploads" folder beneath STATIC_DIR;
# create it up front so later writes cannot fail on a missing directory.
UPLOADS_DIR = os.path.join(STATIC_DIR, "uploads")
os.makedirs(UPLOADS_DIR, exist_ok=True)

# Seed the per-session keys read by main() so they exist on the first run.
if "analysis_result" not in st.session_state:
    st.session_state["analysis_result"] = None
if "uploaded_filename" not in st.session_state:
    st.session_state["uploaded_filename"] = None
|
| 16 |
|
|
|
|
| 17 |
def show_service_status():
    """Render a captioned, three-column row of service status metrics.

    The first two metrics show a check mark when the corresponding credential
    (``HF_TOKEN`` / ``GOOGLE_API_KEY``) is truthy and a cross otherwise; the
    third shows the configured ``DEVICE`` in upper case. A divider is drawn
    underneath the row.
    """
    st.caption("Service Status")

    hf_col, google_col, device_col = st.columns(3)

    hf_badge = "✅" if HF_TOKEN else "❌"
    google_badge = "✅" if GOOGLE_API_KEY else "❌"

    hf_col.metric("HuggingFace Models", hf_badge)
    google_col.metric("Google AI Services", google_badge)
    device_col.metric("Hardware Accelerator", DEVICE.upper())

    st.divider()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
|
|
|
| 26 |
def main():
    """Render the RxGuard page: header, service status, upload flow and results.

    When both ``HF_TOKEN`` and ``GOOGLE_API_KEY`` are set, a file uploader is
    shown. A newly uploaded file (one whose name differs from the name stored
    in the session) is written to ``UPLOADS_DIR`` and passed to
    ``extract_prescription_info``; the outcome, or an ``{"error": ...}`` dict
    if anything raises, is stored in ``st.session_state.analysis_result``.
    Whatever result is stored in the session is then rendered in two tabs.
    """
    st.title("⚕️ RxGuard Prescription Validator")
    st.caption("Advanced, multi-source AI verification system")
    show_service_status()

    # Bound up front so the results section below can reference it even when
    # the uploader is not rendered (missing API keys).
    uploaded_file = None

    # Only enable upload if required services are available
    if all([HF_TOKEN, GOOGLE_API_KEY]):
        uploaded_file = st.file_uploader(
            "Upload a prescription image (PNG/JPG/JPEG):",
            type=["png", "jpg", "jpeg"],
            help="Upload a clear image of the prescription for analysis.",
        )

        if uploaded_file and uploaded_file.name != st.session_state.uploaded_filename:
            with st.status("Analyzing prescription...", expanded=True) as status:
                try:
                    # Remember the name first so the same upload is not
                    # re-analysed on every Streamlit rerun.
                    st.session_state.uploaded_filename = uploaded_file.name

                    # Keep only the final path component so a crafted file
                    # name cannot point outside UPLOADS_DIR.
                    safe_name = os.path.basename(uploaded_file.name.replace("\\", "/")) or "upload"
                    file_path = os.path.join(UPLOADS_DIR, safe_name)
                    with open(file_path, "wb") as f:
                        f.write(uploaded_file.getvalue())

                    # Lazily import the processing function
                    from validate_prescription import extract_prescription_info
                    st.session_state.analysis_result = extract_prescription_info(file_path)
                    status.update(label="Analysis complete!", state="complete", expanded=False)
                except Exception as e:
                    st.error(f"A critical error occurred during processing: {str(e)}")
                    st.session_state.analysis_result = {"error": str(e)}
                    status.update(label="Analysis failed", state="error")
    else:
        st.error("Missing API Keys. Please configure HF_TOKEN and GOOGLE_API_KEY in your Space secrets.")

    # Display results if available in the session state
    result = st.session_state.get("analysis_result")
    if not result:
        return

    if error := result.get("error"):
        st.error(f"❌ Analysis Error: {error}")
        return

    # "info" may be missing or null; fall back to an empty mapping so the
    # .get() lookups below stay safe.
    info = result.get("info") or {}
    tab1, tab2 = st.tabs(["**👤 Patient & Prescription Info**", "**⚙️ Technical Details**"])

    with tab1:
        col1, col2 = st.columns([1, 2])
        with col1:
            if uploaded_file:
                st.image(uploaded_file, use_column_width=True, caption="Uploaded Prescription")
        with col2:
            st.subheader("Patient Details")
            st.info(f"**Name:** {info.get('Name', 'Not detected')}")
            st.info(f"**Age:** {info.get('Age', 'N/A')}")
            st.subheader("Prescription Details")
            st.info(f"**Date:** {info.get('Date', 'N/A')}")
            st.info(f"**Physician:** {info.get('PhysicianName', 'N/A')}")

        st.divider()
        st.subheader("💊 Medications")
        # A null "Medications" value must not crash the loop, and an empty
        # list should be reported rather than silently rendering nothing.
        medications = info.get("Medications") or []
        if not medications:
            st.warning("No medications detected in prescription")
        for med in medications:
            st.success(
                f"**Drug:** {med.get('drug_raw', 'Unknown Medication')} | "
                f"**Dosage:** {med.get('dosage', 'N/A')} | "
                f"**Frequency:** {med.get('frequency', 'N/A')}"
            )

    with tab2:
        st.subheader("Debug Information from AI Pipeline")
        st.json(result.get("debug_info", {}))
|
|
|
|
|
|
|
| 88 |
|
| 89 |
if __name__ == "__main__":
|
| 90 |
main()
|
config.py
CHANGED
|
@@ -2,37 +2,36 @@ import os
|
|
| 2 |
import torch
|
| 3 |
from dotenv import load_dotenv
|
| 4 |
|
|
|
|
| 5 |
load_dotenv()
|
| 6 |
|
| 7 |
-
# ─── Directory
|
| 8 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 9 |
STATIC_DIR = os.path.join(BASE_DIR, 'static')
|
| 10 |
os.makedirs(STATIC_DIR, exist_ok=True)
|
| 11 |
|
| 12 |
-
# ───
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
| 18 |
GOOGLE_APPLICATION_CREDENTIALS = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
|
| 19 |
|
| 20 |
-
#
|
| 21 |
HF_MODELS = {
|
| 22 |
-
|
| 23 |
-
"
|
|
|
|
|
|
|
| 24 |
}
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
# ─── Processing Parameters ─────────────────────────────────────────────────
|
| 28 |
-
LEV_THRESH = 0.75 # Levenshtein similarity threshold
|
| 29 |
-
SIG_THRESH = 0.65 # Signature verification threshold
|
| 30 |
|
| 31 |
-
# ─── File Paths ─────────────────────────────
|
| 32 |
DB_PATH = os.path.join(STATIC_DIR, "rxguard.db")
|
| 33 |
-
|
| 34 |
-
os.makedirs(
|
| 35 |
-
|
| 36 |
-
# ─── Hardware Configuration ────────────────────────────────────────────────
|
| 37 |
-
DEVICE = "cpu" # Force CPU for Hugging Face Spaces compatibility
|
| 38 |
-
USE_GPU = False
|
|
|
|
| 2 |
import torch
|
| 3 |
from dotenv import load_dotenv
|
| 4 |
|
| 5 |
+
# Pull variables from a local .env file into the process environment, if one exists.
load_dotenv()

# --- Directories -------------------------------------------------------------
# Everything the app writes lives under <this file's directory>/static.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(BASE_DIR, "static")
os.makedirs(STATIC_DIR, exist_ok=True)

# --- Hardware ----------------------------------------------------------------
# Use CUDA whenever torch reports it as available; otherwise run on the CPU.
USE_GPU = torch.cuda.is_available()
if USE_GPU:
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Credentials -------------------------------------------------------------
# Read from the environment (e.g. Hugging Face Space secrets); each is None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
GOOGLE_APPLICATION_CREDENTIALS = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")

# --- Models ------------------------------------------------------------------
# Hugging Face model IDs, keyed by the short name used by the pipeline.
HF_MODELS = {
    # Layout-aware model used for the first extraction pass
    "donut": "Javeria98/donut-base-Medical_Handwritten_Prescriptions_Information_Extraction_Final_model1",
    # Model used to re-parse individual medication lines
    "phi3": "Muizzzz8/phi3-prescription-reader",
}
# Gemini model used as the final resolver
GEMINI_MODEL_NAME = "gemini-1.5-flash"

# --- File paths --------------------------------------------------------------
DB_PATH = os.path.join(STATIC_DIR, "rxguard.db")
SIGNATURES_DIR = os.path.join(STATIC_DIR, "signatures")
os.makedirs(SIGNATURES_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,23 +1,23 @@
|
|
| 1 |
# Core
|
| 2 |
streamlit==1.36.0
|
| 3 |
python-dotenv==1.0.1
|
|
|
|
| 4 |
|
| 5 |
-
# AI & Vision
|
| 6 |
-
# Using Google's recommended versions for Gemini and Vision
|
| 7 |
google-generativeai==0.7.1
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
torch==2.3.1
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
# OCR
|
| 14 |
paddleocr==2.7.3
|
| 15 |
-
# Using the CPU version of paddlepaddle for broader compatibility on HF Spaces
|
| 16 |
paddlepaddle==2.6.1
|
| 17 |
|
| 18 |
# Utils
|
| 19 |
numpy==1.26.4
|
| 20 |
-
|
| 21 |
-
opencv-python-headless==4.10.0.84
|
| 22 |
-
scikit-image==0.22.0
|
| 23 |
-
pytz==2024.1
|
|
|
|
| 1 |
# Core
|
| 2 |
streamlit==1.36.0
|
| 3 |
python-dotenv==1.0.1
|
| 4 |
+
pandas==2.2.2
|
| 5 |
|
| 6 |
+
# AI & Vision - Google
|
|
|
|
| 7 |
google-generativeai==0.7.1
|
| 8 |
+
|
| 9 |
+
# AI & Vision - Hugging Face (for T4 GPU with CUDA 12.1)
|
| 10 |
+
--extra-index-url https://download.pytorch.org/whl/cu121
|
| 11 |
torch==2.3.1
|
| 12 |
+
transformers==4.42.3
|
| 13 |
+
accelerate==0.31.0
|
| 14 |
+
bitsandbytes==0.43.1
|
| 15 |
+
sentencepiece==0.2.0
|
| 16 |
|
| 17 |
+
# OCR (as a potential fallback or utility)
|
| 18 |
paddleocr==2.7.3
|
|
|
|
| 19 |
paddlepaddle==2.6.1
|
| 20 |
|
| 21 |
# Utils
|
| 22 |
numpy==1.26.4
|
| 23 |
+
Pillow==10.3.0
|
|
|
|
|
|
|
|
|
validate_prescription.py
CHANGED
|
@@ -2,23 +2,17 @@ import os
|
|
| 2 |
import re
|
| 3 |
import json
|
| 4 |
import logging
|
| 5 |
-
import time
|
| 6 |
-
import numpy as np
|
| 7 |
import tempfile
|
| 8 |
-
import sqlite3
|
| 9 |
import torch
|
| 10 |
-
import
|
| 11 |
-
from typing import Dict, Any, List
|
| 12 |
from PIL import Image
|
| 13 |
-
from dotenv import load_dotenv
|
| 14 |
-
from googleapiclient.discovery import build
|
| 15 |
|
| 16 |
# Suppress verbose backend logs
|
| 17 |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
| 18 |
-
load_dotenv()
|
| 19 |
|
| 20 |
try:
|
| 21 |
-
from transformers import pipeline, DonutProcessor, VisionEncoderDecoderModel
|
| 22 |
from huggingface_hub import login
|
| 23 |
import google.generativeai as genai
|
| 24 |
except ImportError as e:
|
|
@@ -26,68 +20,25 @@ except ImportError as e:
|
|
| 26 |
raise
|
| 27 |
|
| 28 |
from config import (
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
HF_MODELS,
|
| 32 |
-
GOOGLE_API_KEY,
|
| 33 |
-
GOOGLE_APPLICATION_CREDENTIALS,
|
| 34 |
-
GEMINI_MODEL_NAME,
|
| 35 |
-
DEVICE,
|
| 36 |
-
USE_GPU,
|
| 37 |
-
GOOGLE_CSE_ID,
|
| 38 |
)
|
| 39 |
|
| 40 |
-
#
|
| 41 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 42 |
-
if HF_TOKEN:
|
| 43 |
-
login(token=HF_TOKEN)
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
_TEMP_CRED_FILE = None
|
| 48 |
-
|
| 49 |
-
class GoogleSearch:
|
| 50 |
-
"""Performs Google Custom Search API queries."""
|
| 51 |
def __init__(self):
|
| 52 |
-
self.
|
| 53 |
-
self.
|
| 54 |
-
self.
|
| 55 |
-
if self.api_key and self.cse_id:
|
| 56 |
-
try:
|
| 57 |
-
self.service = build("customsearch", "v1", developerKey=self.api_key)
|
| 58 |
-
logging.info("Google Custom Search initialized.")
|
| 59 |
-
except Exception as e:
|
| 60 |
-
logging.error(f"CSE init failed: {e}")
|
| 61 |
-
else:
|
| 62 |
-
logging.warning("GOOGLE_API_KEY or GOOGLE_CSE_ID not set; search disabled.")
|
| 63 |
-
|
| 64 |
-
def search(self, queries: list, num_results: int = 1) -> list:
|
| 65 |
-
if not self.service:
|
| 66 |
-
return []
|
| 67 |
-
out = []
|
| 68 |
-
for q in queries:
|
| 69 |
-
try:
|
| 70 |
-
resp = self.service.cse().list(q=q, cx=self.cse_id, num=num_results).execute()
|
| 71 |
-
items = resp.get("items", [])
|
| 72 |
-
formatted = [
|
| 73 |
-
{"title": it.get("title"), "link": it.get("link"), "snippet": it.get("snippet")}
|
| 74 |
-
for it in items
|
| 75 |
-
]
|
| 76 |
-
out.append({"query": q, "results": formatted})
|
| 77 |
-
except Exception as e:
|
| 78 |
-
logging.error(f"Search error for '{q}': {e}")
|
| 79 |
-
out.append({"query": q, "results": []})
|
| 80 |
-
return out
|
| 81 |
-
|
| 82 |
-
# Initialize Google Search globally
|
| 83 |
-
google_search = GoogleSearch()
|
| 84 |
|
| 85 |
-
def
|
| 86 |
-
|
| 87 |
-
if name not in _MODELS:
|
| 88 |
model_id = HF_MODELS.get(name)
|
| 89 |
-
if not model_id:
|
| 90 |
-
return None
|
| 91 |
|
| 92 |
logging.info(f"Loading model '{name}' ({model_id}) to device '{DEVICE}'...")
|
| 93 |
try:
|
|
@@ -95,242 +46,92 @@ def get_model(name: str):
|
|
| 95 |
if name == "donut":
|
| 96 |
processor = DonutProcessor.from_pretrained(model_id)
|
| 97 |
model = VisionEncoderDecoderModel.from_pretrained(model_id, **quantization_config)
|
| 98 |
-
|
| 99 |
elif name == "phi3":
|
| 100 |
-
model = pipeline(
|
| 101 |
-
|
| 102 |
-
model=model_id,
|
| 103 |
-
torch_dtype=torch.bfloat16,
|
| 104 |
-
trust_remote_code=True,
|
| 105 |
-
**quantization_config
|
| 106 |
-
)
|
| 107 |
-
_MODELS[name] = {"model": model}
|
| 108 |
logging.info(f"Model '{name}' loaded successfully.")
|
| 109 |
except Exception as e:
|
| 110 |
logging.error(f"Failed to load model '{name}': {e}", exc_info=True)
|
| 111 |
-
_MODELS[name] = None
|
| 112 |
-
|
| 113 |
-
return _MODELS.get(name)
|
| 114 |
-
|
| 115 |
-
def get_gemini_client():
|
| 116 |
-
"""Initializes and returns the Gemini client."""
|
| 117 |
-
global _TEMP_CRED_FILE
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
if creds_json_str := GOOGLE_APPLICATION_CREDENTIALS:
|
| 122 |
-
if not
|
| 123 |
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tfp:
|
| 124 |
tfp.write(creds_json_str)
|
| 125 |
-
|
| 126 |
-
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] =
|
| 127 |
-
|
| 128 |
try:
|
| 129 |
genai.configure(api_key=GOOGLE_API_KEY)
|
| 130 |
-
|
| 131 |
except Exception as e:
|
| 132 |
logging.error(f"Gemini init failed: {e}")
|
| 133 |
-
_MODELS["gemini"] = None
|
| 134 |
-
|
| 135 |
-
return _MODELS.get("gemini")
|
| 136 |
-
|
| 137 |
-
def verify_medication_with_google(medication_name: str) -> Dict[str, Any]:
|
| 138 |
-
"""Verifies medication details using Google Search API."""
|
| 139 |
-
if not medication_name:
|
| 140 |
-
return {"error": "No medication name provided"}
|
| 141 |
-
|
| 142 |
-
queries = [
|
| 143 |
-
f"standard dosage for {medication_name}",
|
| 144 |
-
f"side effects of {medication_name}",
|
| 145 |
-
f"drug interactions with {medication_name}"
|
| 146 |
-
]
|
| 147 |
-
|
| 148 |
-
search_results = google_search.search(queries, num_results=2)
|
| 149 |
-
return {
|
| 150 |
-
"medication": medication_name,
|
| 151 |
-
"verification_results": search_results
|
| 152 |
-
}
|
| 153 |
-
|
| 154 |
-
def step1_run_donut(image: Image.Image) -> Dict[str, Any]:
|
| 155 |
-
donut_components = get_model("donut")
|
| 156 |
-
if not donut_components:
|
| 157 |
-
return {"error": "Donut model not available."}
|
| 158 |
-
|
| 159 |
-
model = donut_components["model"].to(DEVICE)
|
| 160 |
-
processor = donut_components["processor"]
|
| 161 |
-
|
| 162 |
-
task_prompt = "<s_cord-v2>"
|
| 163 |
-
decoder_input_ids = processor.tokenizer(
|
| 164 |
-
task_prompt, add_special_tokens=False, return_tensors="pt"
|
| 165 |
-
).input_ids.to(DEVICE)
|
| 166 |
-
pixel_values = processor(image, return_tensors="pt").pixel_values.to(DEVICE)
|
| 167 |
-
|
| 168 |
-
outputs = model.generate(
|
| 169 |
-
pixel_values,
|
| 170 |
-
decoder_input_ids=decoder_input_ids,
|
| 171 |
-
max_length=model.decoder.config.max_position_embeddings,
|
| 172 |
-
early_stopping=True,
|
| 173 |
-
use_cache=True,
|
| 174 |
-
num_beams=1,
|
| 175 |
-
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
| 176 |
-
return_dict_in_generate=True,
|
| 177 |
-
)
|
| 178 |
-
sequence = (
|
| 179 |
-
processor.batch_decode(outputs.sequences)[0]
|
| 180 |
-
.replace(processor.tokenizer.eos_token, "")
|
| 181 |
-
.replace(processor.tokenizer.pad_token, "")
|
| 182 |
-
)
|
| 183 |
-
return processor.token2json(sequence)
|
| 184 |
-
|
| 185 |
-
def step2_run_phi3(medication_text: str) -> str:
|
| 186 |
-
phi3_components = get_model("phi3")
|
| 187 |
-
if not phi3_components:
|
| 188 |
-
return medication_text
|
| 189 |
-
|
| 190 |
-
pipe = phi3_components["model"]
|
| 191 |
-
prompt = (
|
| 192 |
-
f"Normalize the following prescription medication line into its components "
|
| 193 |
-
f"(drug, dosage, frequency). Raw text: '{medication_text}'"
|
| 194 |
-
)
|
| 195 |
-
outputs = pipe(prompt, max_new_tokens=100, do_sample=False)
|
| 196 |
-
return outputs[0]["generated_text"].split("Normalized:")[-1].strip()
|
| 197 |
-
|
| 198 |
-
def step3_run_gemini_resolver(
|
| 199 |
-
image: Image.Image, donut_result: Dict[str, Any], phi3_results: List[str]
|
| 200 |
-
) -> Dict[str, Any]:
|
| 201 |
-
gemini_client = get_gemini_client()
|
| 202 |
-
if not gemini_client:
|
| 203 |
-
return {"error": "Gemini resolver not available."}
|
| 204 |
-
|
| 205 |
-
prompt = f"""
|
| 206 |
-
You are an expert pharmacist's assistant whose sole objective is to reconcile and verify prescription details by cross-referencing multiple AI model outputs against the original prescription image (the ultimate source of truth).
|
| 207 |
-
|
| 208 |
-
Attached Inputs:
|
| 209 |
-
1. Prescription image file
|
| 210 |
-
2. Donut Model Output (layout-aware):
|
| 211 |
-
{json.dumps(donut_result, indent=2)}
|
| 212 |
-
3. Phi-3 Model Output (medication refinement):
|
| 213 |
-
{json.dumps(phi3_results, indent=2)}
|
| 214 |
-
|
| 215 |
-
Please follow these steps **in order**:
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
result = json.loads(response.text)
|
| 264 |
-
|
| 265 |
-
# Enhance with Google Search verification
|
| 266 |
-
for med in result.get("Medications", []):
|
| 267 |
-
if drug_name := med.get("drug_raw"):
|
| 268 |
-
verification = verify_medication_with_google(drug_name)
|
| 269 |
-
med["verification"] = {
|
| 270 |
-
"standard_dosage": self._extract_dosage_info(verification),
|
| 271 |
-
"side_effects": self._extract_side_effects(verification),
|
| 272 |
-
"interactions": self._extract_interactions(verification)
|
| 273 |
-
}
|
| 274 |
-
|
| 275 |
-
return result
|
| 276 |
-
except Exception as e:
|
| 277 |
-
logging.error(f"Gemini resolver failed: {e}")
|
| 278 |
-
return {"error": f"Gemini failed to resolve data: {e}"}
|
| 279 |
-
|
| 280 |
-
def _extract_dosage_info(verification_data: Dict) -> Optional[str]:
|
| 281 |
-
"""Extracts dosage information from verification results."""
|
| 282 |
-
for result in verification_data.get("verification_results", []):
|
| 283 |
-
if "standard dosage" in result.get("query", "").lower():
|
| 284 |
-
return result.get("results", [{}])[0].get("snippet")
|
| 285 |
-
return None
|
| 286 |
|
| 287 |
-
def
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
if "interactions" in result.get("query", "").lower():
|
| 298 |
-
return result.get("results", [{}])[0].get("snippet")
|
| 299 |
-
return None
|
| 300 |
|
| 301 |
def extract_prescription_info(image_path: str) -> Dict[str, Any]:
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
image = Image.open(image_path).convert("RGB")
|
| 305 |
-
|
| 306 |
-
logging.info("Step 1: Running Donut model for layout analysis...")
|
| 307 |
-
donut_data = step1_run_donut(image)
|
| 308 |
-
|
| 309 |
-
medication_lines = [
|
| 310 |
-
item.get("text", "")
|
| 311 |
-
for item in donut_data.get("menu", [])
|
| 312 |
-
if "medi" in item.get("category", "").lower()
|
| 313 |
-
]
|
| 314 |
-
|
| 315 |
-
logging.info("Step 2: Running Phi-3 model for medication refinement...")
|
| 316 |
-
phi3_refined_meds = [step2_run_phi3(line) for line in medication_lines]
|
| 317 |
-
|
| 318 |
-
logging.info("Step 3: Running Gemini model as the expert resolver...")
|
| 319 |
-
final_info = step3_run_gemini_resolver(image, donut_data, phi3_refined_meds)
|
| 320 |
-
|
| 321 |
-
if final_info.get("error"):
|
| 322 |
-
return final_info
|
| 323 |
-
|
| 324 |
-
result = {
|
| 325 |
-
"info": final_info,
|
| 326 |
-
"error": None,
|
| 327 |
-
"debug_info": {
|
| 328 |
-
"donut_output": donut_data,
|
| 329 |
-
"phi3_refinements": phi3_refined_meds,
|
| 330 |
-
},
|
| 331 |
-
}
|
| 332 |
-
return result
|
| 333 |
-
|
| 334 |
-
except Exception as e:
|
| 335 |
-
logging.error(f"Hybrid extraction pipeline failed: {e}", exc_info=True)
|
| 336 |
-
return {"error": f"An unexpected error occurred in the pipeline: {e}"}
|
|
|
|
| 2 |
import re
|
| 3 |
import json
|
| 4 |
import logging
|
|
|
|
|
|
|
| 5 |
import tempfile
|
|
|
|
| 6 |
import torch
|
| 7 |
+
import streamlit as st
|
| 8 |
+
from typing import Dict, Any, List
|
| 9 |
from PIL import Image
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Suppress verbose backend logs
|
| 12 |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
|
|
|
| 13 |
|
| 14 |
try:
|
| 15 |
+
from transformers import pipeline, DonutProcessor, VisionEncoderDecoderModel
|
| 16 |
from huggingface_hub import login
|
| 17 |
import google.generativeai as genai
|
| 18 |
except ImportError as e:
|
|
|
|
| 20 |
raise
|
| 21 |
|
| 22 |
from config import (
|
| 23 |
+
HF_TOKEN, HF_MODELS, GOOGLE_API_KEY,
|
| 24 |
+
GOOGLE_APPLICATION_CREDENTIALS, GEMINI_MODEL_NAME, DEVICE, USE_GPU
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
)
|
| 26 |
|
| 27 |
+
# Module-wide log format, then Hugging Face Hub login when a token is configured.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
if HF_TOKEN:
    login(token=HF_TOKEN)
|
|
|
|
| 30 |
|
| 31 |
+
class PrescriptionProcessor:
|
| 32 |
+
"""Encapsulates the entire hybrid pipeline to resolve the 'self' error."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def __init__(self):
|
| 34 |
+
self.model_cache = {}
|
| 35 |
+
self.temp_cred_file = None
|
| 36 |
+
self._load_all_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
def _load_model(self, name: str):
|
| 39 |
+
if name in self.model_cache: return
|
|
|
|
| 40 |
model_id = HF_MODELS.get(name)
|
| 41 |
+
if not model_id: return
|
|
|
|
| 42 |
|
| 43 |
logging.info(f"Loading model '{name}' ({model_id}) to device '{DEVICE}'...")
|
| 44 |
try:
|
|
|
|
| 46 |
if name == "donut":
|
| 47 |
processor = DonutProcessor.from_pretrained(model_id)
|
| 48 |
model = VisionEncoderDecoderModel.from_pretrained(model_id, **quantization_config)
|
| 49 |
+
self.model_cache[name] = {"model": model, "processor": processor}
|
| 50 |
elif name == "phi3":
|
| 51 |
+
model = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, **quantization_config)
|
| 52 |
+
self.model_cache[name] = {"model": model}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
logging.info(f"Model '{name}' loaded successfully.")
|
| 54 |
except Exception as e:
|
| 55 |
logging.error(f"Failed to load model '{name}': {e}", exc_info=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
+
def _load_gemini_client(self):
    """Configure the Gemini SDK and cache a ``GenerativeModel`` under "gemini".

    Does nothing if a client is already cached. When
    ``GOOGLE_APPLICATION_CREDENTIALS`` holds inline JSON, it is written to a
    temporary ``.json`` file and the environment variable of the same name
    is pointed at that file. When it already names an existing file, it is
    left untouched. Initialisation failures are logged and leave the cache
    without a "gemini" entry.
    """
    if "gemini" in self.model_cache:
        return

    if creds_value := GOOGLE_APPLICATION_CREDENTIALS:
        # The value may already be a path to a key file; only materialise a
        # temp file when it is inline JSON. Writing a path into a ".json"
        # file would leave the env var pointing at an unparsable file.
        if not os.path.isfile(creds_value):
            if not self.temp_cred_file or not os.path.exists(self.temp_cred_file):
                with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tfp:
                    tfp.write(creds_value)
                    self.temp_cred_file = tfp.name
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.temp_cred_file

    try:
        genai.configure(api_key=GOOGLE_API_KEY)
        self.model_cache["gemini"] = genai.GenerativeModel(GEMINI_MODEL_NAME)
    except Exception as e:
        logging.error(f"Gemini init failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
def _load_all_models(self):
|
| 72 |
+
self._load_model("donut")
|
| 73 |
+
self._load_model("phi3")
|
| 74 |
+
self._load_gemini_client()
|
| 75 |
+
|
| 76 |
+
def _run_donut(self, image: Image.Image) -> Dict[str, Any]:
    """Run the cached Donut model on ``image`` and return its output as JSON.

    Returns ``{"error": "Donut model not available."}`` when no Donut entry
    is cached. Otherwise the model generates a token sequence from the image
    and the ``<s_cord-v2>`` task prompt; EOS and PAD tokens are stripped and
    the remainder is converted with ``processor.token2json``.
    """
    donut = self.model_cache.get("donut")
    if not donut:
        return {"error": "Donut model not available."}

    model = donut["model"].to(DEVICE)
    processor = donut["processor"]

    # Prime the decoder with the task prompt; encode the image as pixel values.
    task_prompt = "<s_cord-v2>"
    prompt_ids = processor.tokenizer(
        task_prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids.to(DEVICE)
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(DEVICE)

    generated = model.generate(
        pixel_values,
        decoder_input_ids=prompt_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        use_cache=True,
        num_beams=1,
        # Forbid the unknown token in the generated sequence.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    decoded = processor.batch_decode(generated.sequences)[0]
    decoded = decoded.replace(processor.tokenizer.eos_token, "")
    decoded = decoded.replace(processor.tokenizer.pad_token, "")
    return processor.token2json(decoded)
|
| 86 |
+
|
| 87 |
+
def _run_phi3(self, medication_text: str) -> str:
|
| 88 |
+
components = self.model_cache.get("phi3")
|
| 89 |
+
if not components: return medication_text
|
| 90 |
+
pipe = components["model"]
|
| 91 |
+
prompt = f"Normalize the following prescription medication line into its components (drug, dosage, frequency). Raw text: '{medication_text}'"
|
| 92 |
+
outputs = pipe(prompt, max_new_tokens=100, do_sample=False)
|
| 93 |
+
return outputs[0]['generated_text'].split("Normalized:")[-1].strip()
|
| 94 |
+
|
| 95 |
+
def _run_gemini_resolver(self, image: Image.Image, donut_result: Dict, phi3_results: List[str]) -> Dict[str, Any]:
|
| 96 |
+
gemini_client = self.model_cache.get("gemini")
|
| 97 |
+
if not gemini_client: return {"error": "Gemini resolver not available."}
|
| 98 |
+
prompt = f"""
|
| 99 |
+
You are an expert pharmacist’s assistant...
|
| 100 |
+
(Your detailed prompt from the previous turn goes here)
|
| 101 |
+
...
|
| 102 |
+
**Final JSON Schema**
|
| 103 |
+
```json
|
| 104 |
+
{{
|
| 105 |
+
"Name": "string or null", "Date": "string (MM/DD/YYYY) or null", "Age": "string or null", "PhysicianName": "string or null",
|
| 106 |
+
"Medications": [{{"drug_raw": "string", "dosage": "string or null", "frequency": "string or null"}}]
|
| 107 |
+
}}
|
| 108 |
+
```
|
| 109 |
+
"""
|
| 110 |
+
try:
|
| 111 |
+
response = gemini_client.generate_content([prompt, image], generation_config={"response_mime_type": "application/json"})
|
| 112 |
+
return json.loads(response.text)
|
| 113 |
+
except Exception as e:
|
| 114 |
+
logging.error(f"Gemini resolver failed: {e}")
|
| 115 |
+
# This is where your original error was being generated from
|
| 116 |
+
return {"error": f"Gemini failed to resolve data: {e}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
+
def process(self, image_path: str) -> Dict[str, Any]:
    """Run the Donut -> Phi-3 -> Gemini pipeline on one image file.

    Args:
        image_path: Path of the prescription image to analyse.

    Returns:
        On success, a dict with keys ``"info"`` (the Gemini result),
        ``"error"`` (``None``) and ``"debug_info"`` (the Donut output and
        Phi-3 refinements). If the Gemini step reports an error, its error
        dict is returned as-is. Any exception is logged and returned as
        ``{"error": ...}``.
    """
    try:
        # Close the file handle as soon as the pixels have been converted.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")

        donut_data = self._run_donut(image)

        # NOTE(review): assumes the Donut output uses a "menu" list of
        # {"category", "text"} items — confirm against the model's schema.
        # A lone item may arrive as a dict rather than a one-element list,
        # so normalise before iterating and skip non-dict entries.
        menu = donut_data.get("menu") or []
        if isinstance(menu, dict):
            menu = [menu]
        med_lines = [
            item.get("text") or ""
            for item in menu
            if isinstance(item, dict) and "medi" in str(item.get("category") or "").lower()
        ]

        phi3_refined_meds = [self._run_phi3(line) for line in med_lines]
        final_info = self._run_gemini_resolver(image, donut_data, phi3_refined_meds)
        if final_info.get("error"):
            return final_info

        return {
            "info": final_info,
            "error": None,
            "debug_info": {
                "donut_output": donut_data,
                "phi3_refinements": phi3_refined_meds,
            },
        }
    except Exception as e:
        logging.error(f"Hybrid extraction pipeline failed: {e}", exc_info=True)
        return {"error": f"An unexpected error occurred in the pipeline: {e}"}
|
| 130 |
|
| 131 |
+
@st.cache_resource
def get_processor():
    """Build a ``PrescriptionProcessor``; wrapped in ``st.cache_resource``."""
    processor = PrescriptionProcessor()
    return processor
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
def extract_prescription_info(image_path: str) -> Dict[str, Any]:
    """Analyse the image at ``image_path`` with the processor from ``get_processor``.

    Returns whatever ``PrescriptionProcessor.process`` returns for that path.
    """
    return get_processor().process(image_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|