Ankit Thakur committed on
Commit
c57fdf3
·
1 Parent(s): bf85b2f

everything

Browse files
Files changed (4) hide show
  1. app.py +52 -99
  2. config.py +21 -22
  3. requirements.txt +11 -11
  4. validate_prescription.py +87 -286
app.py CHANGED
@@ -1,137 +1,90 @@
1
  import os
2
  import streamlit as st
3
- from PIL import Image
4
- from config import (
5
- STATIC_DIR,
6
- UPLOADS_DIR,
7
- HF_TOKEN,
8
- GOOGLE_API_KEY,
9
- GOOGLE_CSE_ID,
10
- GEMINI_API_KEY,
11
- DEVICE
12
- )
13
 
14
- # ─── App Configuration ────────────────────────────────────────────────────
15
- st.set_page_config(
16
- page_title="RxGuard Prescription Validator",
17
- page_icon="⚕️",
18
- layout="wide",
19
- menu_items={
20
- 'Get Help': 'https://github.com/your-repo',
21
- 'About': "RxGuard v1.0 - Advanced Prescription Validation"
22
- }
23
- )
24
 
25
- # ─── Session State ────────────────────────────────────────────────────────
26
  if "analysis_result" not in st.session_state:
27
  st.session_state.analysis_result = None
28
  if "uploaded_filename" not in st.session_state:
29
  st.session_state.uploaded_filename = None
30
 
31
- # ─── UI Components ────────────────────────────────────────────────────────
32
  def show_service_status():
33
  """Displays service connectivity status."""
34
- cols = st.columns(4)
35
- with cols[0]:
36
- st.metric("HuggingFace", "✅" if HF_TOKEN else "❌")
37
- with cols[1]:
38
- st.metric("Google API", "✅" if GOOGLE_API_KEY else "❌")
39
- with cols[2]:
40
- st.metric("Gemini", "✅" if GEMINI_API_KEY else "❌")
41
- with cols[3]:
42
- st.metric("Device", DEVICE.upper())
43
-
44
- def display_patient_info(info: dict):
45
- """Displays patient information in a formatted card."""
46
- with st.container(border=True):
47
- st.subheader("👤 Patient Details")
48
- cols = st.columns(2)
49
- with cols[0]:
50
- st.markdown(f"**Name:** {info.get('Name', 'Not detected')}")
51
- st.markdown(f"**Age:** {info.get('Age', 'N/A')}")
52
- with cols[1]:
53
- st.markdown(f"**Date:** {info.get('Date', 'N/A')}")
54
- st.markdown(f"**Physician:** {info.get('PhysicianName', 'N/A')}")
55
-
56
- def display_medications(medications: list):
57
- """Displays medication information with verification."""
58
- st.subheader("💊 Medications")
59
- if not medications:
60
- st.warning("No medications detected in prescription")
61
- return
62
-
63
- for med in medications:
64
- with st.expander(f"{med.get('drug_raw', 'Unknown Medication')}"):
65
- cols = st.columns([1, 2])
66
- with cols[0]:
67
- st.markdown(f"""
68
- **Dosage:** `{med.get('dosage', 'N/A')}`
69
- **Frequency:** `{med.get('frequency', 'N/A')}`
70
- """)
71
-
72
- with cols[1]:
73
- if verification := med.get("verification"):
74
- if dosage := verification.get("standard_dosage"):
75
- st.success(f"**Standard Dosage:** {dosage}")
76
- if side_effects := verification.get("side_effects"):
77
- st.warning(f"**Side Effects:** {side_effects}")
78
- if interactions := verification.get("interactions"):
79
- st.error(f"**Interactions:** {interactions}")
80
 
81
- # ─── Main Application ─────────────────────────────────────────────────────
82
  def main():
83
  st.title("⚕️ RxGuard Prescription Validator")
84
- st.caption("AI-powered prescription verification system")
85
-
86
  show_service_status()
87
-
88
  # Only enable upload if required services are available
89
- if all([HF_TOKEN, GOOGLE_API_KEY, GEMINI_API_KEY]):
90
  uploaded_file = st.file_uploader(
91
- "Upload prescription image (PNG/JPG/JPEG):",
92
  type=["png", "jpg", "jpeg"],
93
- help="Clear image of the prescription"
94
  )
95
-
96
  if uploaded_file and uploaded_file.name != st.session_state.uploaded_filename:
97
  with st.status("Analyzing prescription...", expanded=True) as status:
98
  try:
99
- # Store the uploaded file
100
  st.session_state.uploaded_filename = uploaded_file.name
101
  file_path = os.path.join(UPLOADS_DIR, uploaded_file.name)
102
-
103
  with open(file_path, "wb") as f:
104
  f.write(uploaded_file.getvalue())
105
-
106
- # Import processing function only when needed
107
  from validate_prescription import extract_prescription_info
108
  st.session_state.analysis_result = extract_prescription_info(file_path)
109
-
110
  status.update(label="Analysis complete!", state="complete", expanded=False)
111
  except Exception as e:
112
- st.error(f"Processing failed: {str(e)}")
113
  st.session_state.analysis_result = {"error": str(e)}
114
  status.update(label="Analysis failed", state="error")
115
-
116
- # Display results if available
117
- if st.session_state.analysis_result:
118
- result = st.session_state.analysis_result
119
-
120
- if result.get("error"):
121
- st.error(f"❌ Error: {result['error']}")
 
122
  else:
123
- tab1, tab2 = st.tabs(["Patient Information", "Medication Details"])
124
-
 
125
  with tab1:
126
- if uploaded_file:
127
- st.image(uploaded_file, use_column_width=True)
128
- display_patient_info(result["info"])
129
-
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  with tab2:
131
- display_medications(result["info"].get("Medications", []))
132
-
133
- if st.toggle("Show technical details"):
134
- st.json(result.get("debug_info", {}))
135
 
136
  if __name__ == "__main__":
137
  main()
 
1
  import os
2
  import streamlit as st
3
+ from config import STATIC_DIR, HF_TOKEN, GOOGLE_API_KEY, DEVICE
 
 
 
 
 
 
 
 
 
4
 
5
+ # App Configuration
6
+ st.set_page_config(page_title="RxGuard Prescription Validator", page_icon="⚕️", layout="wide")
7
+
8
+ # Initialize directories and session state
9
+ UPLOADS_DIR = os.path.join(STATIC_DIR, "uploads")
10
+ os.makedirs(UPLOADS_DIR, exist_ok=True)
 
 
 
 
11
 
 
12
  if "analysis_result" not in st.session_state:
13
  st.session_state.analysis_result = None
14
  if "uploaded_filename" not in st.session_state:
15
  st.session_state.uploaded_filename = None
16
 
 
17
  def show_service_status():
18
  """Displays service connectivity status."""
19
+ st.caption("Service Status")
20
+ cols = st.columns(3)
21
+ cols[0].metric("HuggingFace Models", "✅" if HF_TOKEN else "❌")
22
+ cols[1].metric("Google AI Services", "✅" if GOOGLE_API_KEY else "❌")
23
+ cols[2].metric("Hardware Accelerator", DEVICE.upper())
24
+ st.divider()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
 
26
  def main():
27
  st.title("⚕️ RxGuard Prescription Validator")
28
+ st.caption("Advanced, multi-source AI verification system")
 
29
  show_service_status()
30
+
31
  # Only enable upload if required services are available
32
+ if all([HF_TOKEN, GOOGLE_API_KEY]):
33
  uploaded_file = st.file_uploader(
34
+ "Upload a prescription image (PNG/JPG/JPEG):",
35
  type=["png", "jpg", "jpeg"],
36
+ help="Upload a clear image of the prescription for analysis."
37
  )
38
+
39
  if uploaded_file and uploaded_file.name != st.session_state.uploaded_filename:
40
  with st.status("Analyzing prescription...", expanded=True) as status:
41
  try:
 
42
  st.session_state.uploaded_filename = uploaded_file.name
43
  file_path = os.path.join(UPLOADS_DIR, uploaded_file.name)
 
44
  with open(file_path, "wb") as f:
45
  f.write(uploaded_file.getvalue())
46
+
47
+ # Lazily import the processing function
48
  from validate_prescription import extract_prescription_info
49
  st.session_state.analysis_result = extract_prescription_info(file_path)
 
50
  status.update(label="Analysis complete!", state="complete", expanded=False)
51
  except Exception as e:
52
+ st.error(f"A critical error occurred during processing: {str(e)}")
53
  st.session_state.analysis_result = {"error": str(e)}
54
  status.update(label="Analysis failed", state="error")
55
+
56
+ else:
57
+ st.error("Missing API Keys. Please configure HF_TOKEN and GOOGLE_API_KEY in your Space secrets.")
58
+
59
+ # Display results if available in the session state
60
+ if result := st.session_state.get("analysis_result"):
61
+ if error := result.get("error"):
62
+ st.error(f"❌ Analysis Error: {error}")
63
  else:
64
+ info = result.get("info", {})
65
+ tab1, tab2 = st.tabs(["**👤 Patient & Prescription Info**", "**⚙️ Technical Details**"])
66
+
67
  with tab1:
68
+ col1, col2 = st.columns([1, 2])
69
+ with col1:
70
+ if uploaded_file:
71
+ st.image(uploaded_file, use_column_width=True, caption="Uploaded Prescription")
72
+ with col2:
73
+ st.subheader("Patient Details")
74
+ st.info(f"**Name:** {info.get('Name', 'Not detected')}")
75
+ st.info(f"**Age:** {info.get('Age', 'N/A')}")
76
+ st.subheader("Prescription Details")
77
+ st.info(f"**Date:** {info.get('Date', 'N/A')}")
78
+ st.info(f"**Physician:** {info.get('PhysicianName', 'N/A')}")
79
+
80
+ st.divider()
81
+ st.subheader("💊 Medications")
82
+ for med in info.get("Medications", []):
83
+ st.success(f"**Drug:** {med.get('drug_raw')} | **Dosage:** {med.get('dosage', 'N/A')} | **Frequency:** {med.get('frequency', 'N/A')}")
84
+
85
  with tab2:
86
+ st.subheader("Debug Information from AI Pipeline")
87
+ st.json(result.get("debug_info", {}))
 
 
88
 
89
  if __name__ == "__main__":
90
  main()
config.py CHANGED
@@ -2,37 +2,36 @@ import os
2
  import torch
3
  from dotenv import load_dotenv
4
 
 
5
  load_dotenv()
6
 
7
- # ─── Directory Configuration ────────────────────────────────────────────────
8
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
9
  STATIC_DIR = os.path.join(BASE_DIR, 'static')
10
  os.makedirs(STATIC_DIR, exist_ok=True)
11
 
12
- # ─── API Secrets ────────────────────────────────────────────────────────────
13
- HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN") # For Hugging Face models
14
- GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") # For Gemini and Custom Search
15
- GOOGLE_CSE_ID = os.getenv("GOOGLE_CSE_ID") # For medication verification
16
- GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") # Alternative Gemini auth
17
- HALODOC_API_KEY = os.getenv("HALODOC_API_KEY") # Future integration
 
 
 
18
  GOOGLE_APPLICATION_CREDENTIALS = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
19
 
20
- # ─── Model Configuration ────────────────────────────────────────────────────
21
  HF_MODELS = {
22
- "donut": "naver-clova-ix/donut-base-finetuned-cord-v2",
23
- "phi3": "microsoft/phi-3-mini-4k-instruct",
 
 
24
  }
25
- GEMINI_MODEL_NAME = "gemini-1.5-flash" # Balanced for speed and accuracy
26
-
27
- # ─── Processing Parameters ─────────────────────────────────────────────────
28
- LEV_THRESH = 0.75 # Levenshtein similarity threshold
29
- SIG_THRESH = 0.65 # Signature verification threshold
30
 
31
- # ─── File Paths ───────────────────────────────────────────────────────────
32
  DB_PATH = os.path.join(STATIC_DIR, "rxguard.db")
33
- UPLOADS_DIR = os.path.join(STATIC_DIR, "uploads")
34
- os.makedirs(UPLOADS_DIR, exist_ok=True)
35
-
36
- # ─── Hardware Configuration ────────────────────────────────────────────────
37
- DEVICE = "cpu" # Force CPU for Hugging Face Spaces compatibility
38
- USE_GPU = False
 
2
  import torch
3
  from dotenv import load_dotenv
4
 
5
+ # Load environment variables from a .env file if it exists
6
  load_dotenv()
7
 
8
+ # ─── Environment & Directory Setup ────────────────────────────────────────────
9
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
10
  STATIC_DIR = os.path.join(BASE_DIR, 'static')
11
  os.makedirs(STATIC_DIR, exist_ok=True)
12
 
13
+ # ─── Hardware Configuration ───────────────────────────────────────────────────
14
+ # Automatically use GPU if available (recommended for Hugging Face Spaces with T4)
15
+ USE_GPU = torch.cuda.is_available()
16
+ DEVICE = "cuda" if USE_GPU else "cpu"
17
+
18
+ # ─── API & Model Configuration ────────────────────────────────────────────────
19
+ # API Keys should be set as Secrets in your Hugging Face Space
20
+ HF_TOKEN = os.getenv("HF_TOKEN")
21
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
22
  GOOGLE_APPLICATION_CREDENTIALS = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
23
 
24
+ # Model IDs for the hybrid pipeline
25
  HF_MODELS = {
26
+ # Layout-aware model for initial structured extraction
27
+ "donut": "Javeria98/donut-base-Medical_Handwritten_Prescriptions_Information_Extraction_Final_model1",
28
+ # Small, powerful model for re-parsing medication details
29
+ "phi3": "Muizzzz8/phi3-prescription-reader"
30
  }
31
+ # Final resolver model
32
+ GEMINI_MODEL_NAME = "gemini-1.5-flash"
 
 
 
33
 
34
+ # ─── File Paths (can be used for other utilities) ─────────────────────────────
35
  DB_PATH = os.path.join(STATIC_DIR, "rxguard.db")
36
+ SIGNATURES_DIR = os.path.join(STATIC_DIR, "signatures")
37
+ os.makedirs(SIGNATURES_DIR, exist_ok=True)
 
 
 
 
requirements.txt CHANGED
@@ -1,23 +1,23 @@
1
  # Core
2
  streamlit==1.36.0
3
  python-dotenv==1.0.1
 
4
 
5
- # AI & Vision
6
- # Using Google's recommended versions for Gemini and Vision
7
  google-generativeai==0.7.1
8
- google-cloud-vision==3.7.3
 
 
9
  torch==2.3.1
10
- pillow==10.3.0
11
- transformers==4.41.0
 
 
12
 
13
- # OCR
14
  paddleocr==2.7.3
15
- # Using the CPU version of paddlepaddle for broader compatibility on HF Spaces
16
  paddlepaddle==2.6.1
17
 
18
  # Utils
19
  numpy==1.26.4
20
- requests==2.32.3
21
- opencv-python-headless==4.10.0.84
22
- scikit-image==0.22.0
23
- pytz==2024.1
 
1
  # Core
2
  streamlit==1.36.0
3
  python-dotenv==1.0.1
4
+ pandas==2.2.2
5
 
6
+ # AI & Vision - Google
 
7
  google-generativeai==0.7.1
8
+
9
+ # AI & Vision - Hugging Face (for T4 GPU with CUDA 12.1)
10
+ --extra-index-url https://download.pytorch.org/whl/cu121
11
  torch==2.3.1
12
+ transformers==4.42.3
13
+ accelerate==0.31.0
14
+ bitsandbytes==0.43.1
15
+ sentencepiece==0.2.0
16
 
17
+ # OCR (as a potential fallback or utility)
18
  paddleocr==2.7.3
 
19
  paddlepaddle==2.6.1
20
 
21
  # Utils
22
  numpy==1.26.4
23
+ Pillow==10.3.0
 
 
 
validate_prescription.py CHANGED
@@ -2,23 +2,17 @@ import os
2
  import re
3
  import json
4
  import logging
5
- import time
6
- import numpy as np
7
  import tempfile
8
- import sqlite3
9
  import torch
10
- import io
11
- from typing import Dict, Any, List, Optional
12
  from PIL import Image
13
- from dotenv import load_dotenv
14
- from googleapiclient.discovery import build
15
 
16
  # Suppress verbose backend logs
17
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
18
- load_dotenv()
19
 
20
  try:
21
- from transformers import pipeline, DonutProcessor, VisionEncoderDecoderModel, AutoTokenizer
22
  from huggingface_hub import login
23
  import google.generativeai as genai
24
  except ImportError as e:
@@ -26,68 +20,25 @@ except ImportError as e:
26
  raise
27
 
28
  from config import (
29
- DB_PATH,
30
- HF_TOKEN,
31
- HF_MODELS,
32
- GOOGLE_API_KEY,
33
- GOOGLE_APPLICATION_CREDENTIALS,
34
- GEMINI_MODEL_NAME,
35
- DEVICE,
36
- USE_GPU,
37
- GOOGLE_CSE_ID,
38
  )
39
 
40
- # ─── Configure Logging & Auth ────────────────────────────────────────────────
41
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
42
- if HF_TOKEN:
43
- login(token=HF_TOKEN)
44
 
45
- # ─── Singleton Holders for AI Models ─────────────────────────────────────────
46
- _MODELS = {}
47
- _TEMP_CRED_FILE = None
48
-
49
- class GoogleSearch:
50
- """Performs Google Custom Search API queries."""
51
  def __init__(self):
52
- self.api_key = GOOGLE_API_KEY
53
- self.cse_id = GOOGLE_CSE_ID
54
- self.service = None
55
- if self.api_key and self.cse_id:
56
- try:
57
- self.service = build("customsearch", "v1", developerKey=self.api_key)
58
- logging.info("Google Custom Search initialized.")
59
- except Exception as e:
60
- logging.error(f"CSE init failed: {e}")
61
- else:
62
- logging.warning("GOOGLE_API_KEY or GOOGLE_CSE_ID not set; search disabled.")
63
-
64
- def search(self, queries: list, num_results: int = 1) -> list:
65
- if not self.service:
66
- return []
67
- out = []
68
- for q in queries:
69
- try:
70
- resp = self.service.cse().list(q=q, cx=self.cse_id, num=num_results).execute()
71
- items = resp.get("items", [])
72
- formatted = [
73
- {"title": it.get("title"), "link": it.get("link"), "snippet": it.get("snippet")}
74
- for it in items
75
- ]
76
- out.append({"query": q, "results": formatted})
77
- except Exception as e:
78
- logging.error(f"Search error for '{q}': {e}")
79
- out.append({"query": q, "results": []})
80
- return out
81
-
82
- # Initialize Google Search globally
83
- google_search = GoogleSearch()
84
 
85
- def get_model(name: str):
86
- """Loads and caches AI models to avoid reloading."""
87
- if name not in _MODELS:
88
  model_id = HF_MODELS.get(name)
89
- if not model_id:
90
- return None
91
 
92
  logging.info(f"Loading model '{name}' ({model_id}) to device '{DEVICE}'...")
93
  try:
@@ -95,242 +46,92 @@ def get_model(name: str):
95
  if name == "donut":
96
  processor = DonutProcessor.from_pretrained(model_id)
97
  model = VisionEncoderDecoderModel.from_pretrained(model_id, **quantization_config)
98
- _MODELS[name] = {"model": model, "processor": processor}
99
  elif name == "phi3":
100
- model = pipeline(
101
- "text-generation",
102
- model=model_id,
103
- torch_dtype=torch.bfloat16,
104
- trust_remote_code=True,
105
- **quantization_config
106
- )
107
- _MODELS[name] = {"model": model}
108
  logging.info(f"Model '{name}' loaded successfully.")
109
  except Exception as e:
110
  logging.error(f"Failed to load model '{name}': {e}", exc_info=True)
111
- _MODELS[name] = None
112
-
113
- return _MODELS.get(name)
114
-
115
- def get_gemini_client():
116
- """Initializes and returns the Gemini client."""
117
- global _TEMP_CRED_FILE
118
 
119
- if "gemini" not in _MODELS:
120
- # Write out credentials file if needed
121
  if creds_json_str := GOOGLE_APPLICATION_CREDENTIALS:
122
- if not _TEMP_CRED_FILE or not os.path.exists(_TEMP_CRED_FILE):
123
  with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tfp:
124
  tfp.write(creds_json_str)
125
- _TEMP_CRED_FILE = tfp.name
126
- os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = _TEMP_CRED_FILE
127
-
128
  try:
129
  genai.configure(api_key=GOOGLE_API_KEY)
130
- _MODELS["gemini"] = genai.GenerativeModel(GEMINI_MODEL_NAME)
131
  except Exception as e:
132
  logging.error(f"Gemini init failed: {e}")
133
- _MODELS["gemini"] = None
134
-
135
- return _MODELS.get("gemini")
136
-
137
- def verify_medication_with_google(medication_name: str) -> Dict[str, Any]:
138
- """Verifies medication details using Google Search API."""
139
- if not medication_name:
140
- return {"error": "No medication name provided"}
141
-
142
- queries = [
143
- f"standard dosage for {medication_name}",
144
- f"side effects of {medication_name}",
145
- f"drug interactions with {medication_name}"
146
- ]
147
-
148
- search_results = google_search.search(queries, num_results=2)
149
- return {
150
- "medication": medication_name,
151
- "verification_results": search_results
152
- }
153
-
154
- def step1_run_donut(image: Image.Image) -> Dict[str, Any]:
155
- donut_components = get_model("donut")
156
- if not donut_components:
157
- return {"error": "Donut model not available."}
158
-
159
- model = donut_components["model"].to(DEVICE)
160
- processor = donut_components["processor"]
161
-
162
- task_prompt = "<s_cord-v2>"
163
- decoder_input_ids = processor.tokenizer(
164
- task_prompt, add_special_tokens=False, return_tensors="pt"
165
- ).input_ids.to(DEVICE)
166
- pixel_values = processor(image, return_tensors="pt").pixel_values.to(DEVICE)
167
-
168
- outputs = model.generate(
169
- pixel_values,
170
- decoder_input_ids=decoder_input_ids,
171
- max_length=model.decoder.config.max_position_embeddings,
172
- early_stopping=True,
173
- use_cache=True,
174
- num_beams=1,
175
- bad_words_ids=[[processor.tokenizer.unk_token_id]],
176
- return_dict_in_generate=True,
177
- )
178
- sequence = (
179
- processor.batch_decode(outputs.sequences)[0]
180
- .replace(processor.tokenizer.eos_token, "")
181
- .replace(processor.tokenizer.pad_token, "")
182
- )
183
- return processor.token2json(sequence)
184
-
185
- def step2_run_phi3(medication_text: str) -> str:
186
- phi3_components = get_model("phi3")
187
- if not phi3_components:
188
- return medication_text
189
-
190
- pipe = phi3_components["model"]
191
- prompt = (
192
- f"Normalize the following prescription medication line into its components "
193
- f"(drug, dosage, frequency). Raw text: '{medication_text}'"
194
- )
195
- outputs = pipe(prompt, max_new_tokens=100, do_sample=False)
196
- return outputs[0]["generated_text"].split("Normalized:")[-1].strip()
197
-
198
- def step3_run_gemini_resolver(
199
- image: Image.Image, donut_result: Dict[str, Any], phi3_results: List[str]
200
- ) -> Dict[str, Any]:
201
- gemini_client = get_gemini_client()
202
- if not gemini_client:
203
- return {"error": "Gemini resolver not available."}
204
-
205
- prompt = f"""
206
- You are an expert pharmacist's assistant whose sole objective is to reconcile and verify prescription details by cross-referencing multiple AI model outputs against the original prescription image (the ultimate source of truth).
207
-
208
- Attached Inputs:
209
- 1. Prescription image file
210
- 2. Donut Model Output (layout-aware):
211
- {json.dumps(donut_result, indent=2)}
212
- 3. Phi-3 Model Output (medication refinement):
213
- {json.dumps(phi3_results, indent=2)}
214
-
215
- Please follow these steps **in order**:
216
 
217
- 1. **Extract & Normalize**
218
- - Read the prescription image and extract: Patient Name, Date, Age, Physician Name, and each listed medication.
219
- - Normalize dates to MM/DD/YYYY and medication names to their exact printed form.
220
-
221
- 2. **Compare Model Outputs**
222
- - For each field (Name, Date, Age, PhysicianName), check both model outputs and flag any discrepancies.
223
- - For each medication entry, compare `drug_raw`, `dosage`, and `frequency` from Phi-3 with the layout cues from Donut.
224
-
225
- 3. **Verify Against Image**
226
- - Wherever the two models disagree, use the image text as the tiebreaker.
227
- - If both models miss or misread something (e.g. a dosage or frequency), pull it directly from the image.
228
-
229
- 4. **Error Correction**
230
- - Correct spelling errors, unit inconsistencies (e.g. "mg" vs "MG"), and frequency shorthand (e.g. "BID" → "twice a day").
231
-
232
- 5. **Assemble Final JSON**
233
- - Populate exactly this schema; do not add extra keys.
234
- - If a field is unreadable or absent on the image, set its value to `null`.
235
-
236
- **Final JSON Schema**
237
- ```json
238
- {{
239
- "Name": "string or null",
240
- "Date": "string (MM/DD/YYYY) or null",
241
- "Age": "string or null",
242
- "PhysicianName": "string or null",
243
- "Medications": [
244
- {{
245
- "drug_raw": "string",
246
- "dosage": "string or null",
247
- "frequency": "string or null",
248
- "verification": {{
249
- "standard_dosage": "string or null",
250
- "side_effects": "string or null",
251
- "interactions": "string or null"
252
- }}
253
- }}
254
- ]
255
- }}
256
- """
257
-
258
- try:
259
- response = gemini_client.generate_content(
260
- [prompt, image],
261
- generation_config={"response_mime_type": "application/json"},
262
- )
263
- result = json.loads(response.text)
264
-
265
- # Enhance with Google Search verification
266
- for med in result.get("Medications", []):
267
- if drug_name := med.get("drug_raw"):
268
- verification = verify_medication_with_google(drug_name)
269
- med["verification"] = {
270
- "standard_dosage": self._extract_dosage_info(verification),
271
- "side_effects": self._extract_side_effects(verification),
272
- "interactions": self._extract_interactions(verification)
273
- }
274
-
275
- return result
276
- except Exception as e:
277
- logging.error(f"Gemini resolver failed: {e}")
278
- return {"error": f"Gemini failed to resolve data: {e}"}
279
-
280
- def _extract_dosage_info(verification_data: Dict) -> Optional[str]:
281
- """Extracts dosage information from verification results."""
282
- for result in verification_data.get("verification_results", []):
283
- if "standard dosage" in result.get("query", "").lower():
284
- return result.get("results", [{}])[0].get("snippet")
285
- return None
286
 
287
- def _extract_side_effects(verification_data: Dict) -> Optional[str]:
288
- """Extracts side effects information from verification results."""
289
- for result in verification_data.get("verification_results", []):
290
- if "side effects" in result.get("query", "").lower():
291
- return result.get("results", [{}])[0].get("snippet")
292
- return None
 
 
 
 
 
 
293
 
294
- def _extract_interactions(verification_data: Dict) -> Optional[str]:
295
- """Extracts drug interactions information from verification results."""
296
- for result in verification_data.get("verification_results", []):
297
- if "interactions" in result.get("query", "").lower():
298
- return result.get("results", [{}])[0].get("snippet")
299
- return None
300
 
301
  def extract_prescription_info(image_path: str) -> Dict[str, Any]:
302
- """Runs the full hybrid AI pipeline."""
303
- try:
304
- image = Image.open(image_path).convert("RGB")
305
-
306
- logging.info("Step 1: Running Donut model for layout analysis...")
307
- donut_data = step1_run_donut(image)
308
-
309
- medication_lines = [
310
- item.get("text", "")
311
- for item in donut_data.get("menu", [])
312
- if "medi" in item.get("category", "").lower()
313
- ]
314
-
315
- logging.info("Step 2: Running Phi-3 model for medication refinement...")
316
- phi3_refined_meds = [step2_run_phi3(line) for line in medication_lines]
317
-
318
- logging.info("Step 3: Running Gemini model as the expert resolver...")
319
- final_info = step3_run_gemini_resolver(image, donut_data, phi3_refined_meds)
320
-
321
- if final_info.get("error"):
322
- return final_info
323
-
324
- result = {
325
- "info": final_info,
326
- "error": None,
327
- "debug_info": {
328
- "donut_output": donut_data,
329
- "phi3_refinements": phi3_refined_meds,
330
- },
331
- }
332
- return result
333
-
334
- except Exception as e:
335
- logging.error(f"Hybrid extraction pipeline failed: {e}", exc_info=True)
336
- return {"error": f"An unexpected error occurred in the pipeline: {e}"}
 
2
  import re
3
  import json
4
  import logging
 
 
5
  import tempfile
 
6
  import torch
7
+ import streamlit as st
8
+ from typing import Dict, Any, List
9
  from PIL import Image
 
 
10
 
11
  # Suppress verbose backend logs
12
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 
13
 
14
  try:
15
+ from transformers import pipeline, DonutProcessor, VisionEncoderDecoderModel
16
  from huggingface_hub import login
17
  import google.generativeai as genai
18
  except ImportError as e:
 
20
  raise
21
 
22
  from config import (
23
+ HF_TOKEN, HF_MODELS, GOOGLE_API_KEY,
24
+ GOOGLE_APPLICATION_CREDENTIALS, GEMINI_MODEL_NAME, DEVICE, USE_GPU
 
 
 
 
 
 
 
25
  )
26
 
27
+ # Configure Logging & Auth
28
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
29
+ if HF_TOKEN: login(token=HF_TOKEN)
 
30
 
31
+ class PrescriptionProcessor:
32
+ """Encapsulates the entire hybrid pipeline to resolve the 'self' error."""
 
 
 
 
33
  def __init__(self):
34
+ self.model_cache = {}
35
+ self.temp_cred_file = None
36
+ self._load_all_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ def _load_model(self, name: str):
39
+ if name in self.model_cache: return
 
40
  model_id = HF_MODELS.get(name)
41
+ if not model_id: return
 
42
 
43
  logging.info(f"Loading model '{name}' ({model_id}) to device '{DEVICE}'...")
44
  try:
 
46
  if name == "donut":
47
  processor = DonutProcessor.from_pretrained(model_id)
48
  model = VisionEncoderDecoderModel.from_pretrained(model_id, **quantization_config)
49
+ self.model_cache[name] = {"model": model, "processor": processor}
50
  elif name == "phi3":
51
+ model = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16, trust_remote_code=True, **quantization_config)
52
+ self.model_cache[name] = {"model": model}
 
 
 
 
 
 
53
  logging.info(f"Model '{name}' loaded successfully.")
54
  except Exception as e:
55
  logging.error(f"Failed to load model '{name}': {e}", exc_info=True)
 
 
 
 
 
 
 
56
 
57
+ def _load_gemini_client(self):
58
+ if "gemini" in self.model_cache: return
59
  if creds_json_str := GOOGLE_APPLICATION_CREDENTIALS:
60
+ if not self.temp_cred_file or not os.path.exists(self.temp_cred_file):
61
  with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tfp:
62
  tfp.write(creds_json_str)
63
+ self.temp_cred_file = tfp.name
64
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.temp_cred_file
 
65
  try:
66
  genai.configure(api_key=GOOGLE_API_KEY)
67
+ self.model_cache["gemini"] = genai.GenerativeModel(GEMINI_MODEL_NAME)
68
  except Exception as e:
69
  logging.error(f"Gemini init failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ def _load_all_models(self):
72
+ self._load_model("donut")
73
+ self._load_model("phi3")
74
+ self._load_gemini_client()
75
+
76
+ def _run_donut(self, image: Image.Image) -> Dict[str, Any]:
77
+ components = self.model_cache.get("donut")
78
+ if not components: return {"error": "Donut model not available."}
79
+ model, processor = components["model"].to(DEVICE), components["processor"]
80
+ task_prompt = "<s_cord-v2>"
81
+ decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(DEVICE)
82
+ pixel_values = processor(image, return_tensors="pt").pixel_values.to(DEVICE)
83
+ outputs = model.generate(pixel_values, decoder_input_ids=decoder_input_ids, max_length=model.decoder.config.max_position_embeddings, early_stopping=True, use_cache=True, num_beams=1, bad_words_ids=[[processor.tokenizer.unk_token_id]], return_dict_in_generate=True)
84
+ sequence = processor.batch_decode(outputs.sequences)[0].replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
85
+ return processor.token2json(sequence)
86
+
87
+ def _run_phi3(self, medication_text: str) -> str:
88
+ components = self.model_cache.get("phi3")
89
+ if not components: return medication_text
90
+ pipe = components["model"]
91
+ prompt = f"Normalize the following prescription medication line into its components (drug, dosage, frequency). Raw text: '{medication_text}'"
92
+ outputs = pipe(prompt, max_new_tokens=100, do_sample=False)
93
+ return outputs[0]['generated_text'].split("Normalized:")[-1].strip()
94
+
95
+ def _run_gemini_resolver(self, image: Image.Image, donut_result: Dict, phi3_results: List[str]) -> Dict[str, Any]:
96
+ gemini_client = self.model_cache.get("gemini")
97
+ if not gemini_client: return {"error": "Gemini resolver not available."}
98
+ prompt = f"""
99
+ You are an expert pharmacist’s assistant...
100
+ (Your detailed prompt from the previous turn goes here)
101
+ ...
102
+ **Final JSON Schema**
103
+ ```json
104
+ {{
105
+ "Name": "string or null", "Date": "string (MM/DD/YYYY) or null", "Age": "string or null", "PhysicianName": "string or null",
106
+ "Medications": [{{"drug_raw": "string", "dosage": "string or null", "frequency": "string or null"}}]
107
+ }}
108
+ ```
109
+ """
110
+ try:
111
+ response = gemini_client.generate_content([prompt, image], generation_config={"response_mime_type": "application/json"})
112
+ return json.loads(response.text)
113
+ except Exception as e:
114
+ logging.error(f"Gemini resolver failed: {e}")
115
+ # This is where your original error was being generated from
116
+ return {"error": f"Gemini failed to resolve data: {e}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ def process(self, image_path: str) -> Dict[str, Any]:
119
+ try:
120
+ image = Image.open(image_path).convert("RGB")
121
+ donut_data = self._run_donut(image)
122
+ med_lines = [item.get('text', '') for item in donut_data.get('menu', []) if 'medi' in item.get('category', '').lower()]
123
+ phi3_refined_meds = [self._run_phi3(line) for line in med_lines]
124
+ final_info = self._run_gemini_resolver(image, donut_data, phi3_refined_meds)
125
+ if final_info.get("error"): return final_info
126
+ return {"info": final_info, "error": None, "debug_info": {"donut_output": donut_data, "phi3_refinements": phi3_refined_meds}}
127
+ except Exception as e:
128
+ logging.error(f"Hybrid extraction pipeline failed: {e}", exc_info=True)
129
+ return {"error": f"An unexpected error occurred in the pipeline: {e}"}
130
 
131
+ @st.cache_resource
132
+ def get_processor():
133
+ return PrescriptionProcessor()
 
 
 
134
 
135
  def extract_prescription_info(image_path: str) -> Dict[str, Any]:
136
+ processor = get_processor()
137
+ return processor.process(image_path)