Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,388 +1,111 @@
|
|
| 1 |
-
#
|
|
|
|
| 2 |
import streamlit as st
|
| 3 |
-
|
| 4 |
-
import chromadb
|
| 5 |
-
from chromadb.utils import embedding_functions
|
| 6 |
from PIL import Image
|
| 7 |
-
import
|
| 8 |
-
import
|
| 9 |
-
import
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
|
| 50 |
-
]
|
| 51 |
-
GEMINI_ANALYSIS_PROMPT = """Analyze this medical image (e.g., pathology slide, diagram, scan).
|
| 52 |
-
Describe the key visual features relevant to a medical context.
|
| 53 |
-
Identify potential:
|
| 54 |
-
- Diseases or conditions indicated
|
| 55 |
-
- Pathological findings (e.g., cellular morphology, tissue structure, staining patterns)
|
| 56 |
-
- Visible cell types
|
| 57 |
-
- Relevant biomarkers (if inferable from staining or morphology)
|
| 58 |
-
- Anatomical context (if discernible)
|
| 59 |
-
|
| 60 |
-
Be concise and focus primarily on visually evident information. Avoid definitive diagnoses.
|
| 61 |
-
Structure the output clearly, perhaps using bullet points for findings.
|
| 62 |
-
"""
|
| 63 |
-
|
| 64 |
-
# Chroma DB Configuration
|
| 65 |
-
CHROMA_PATH = "chroma_data_biobert" # Changed path to reflect model change
|
| 66 |
-
COLLECTION_NAME = "medical_docs_biobert" # Changed collection name
|
| 67 |
-
|
| 68 |
-
# --- Embedding Model Selection ---
|
| 69 |
-
# Using BioBERT v1.1 - Good domain knowledge, but potentially suboptimal for *semantic similarity search*.
|
| 70 |
-
# Default pooling (likely CLS token) will be used by sentence-transformers.
|
| 71 |
-
# Consider models fine-tuned for sentence similarity if retrieval quality is low:
|
| 72 |
-
# e.g., 'dmis-lab/sapbert-from-pubmedbert-sentencetransformer'
|
| 73 |
-
EMBEDDING_MODEL_NAME = "dmis-lab/biobert-v1.1"
|
| 74 |
-
CHROMA_DISTANCE_METRIC = "cosine" # Cosine is generally good for sentence embeddings
|
| 75 |
-
|
| 76 |
-
# --- Caching Resource Initialization ---
|
| 77 |
-
|
| 78 |
-
@st.cache_resource
def initialize_gemini_model() -> Optional[genai.GenerativeModel]:
    """Build and cache the Gemini vision model for this session.

    Returns:
        The configured ``genai.GenerativeModel`` on success, or ``None`` when
        configuration/construction fails (the failure is shown in the UI and
        written to the log).
    """
    try:
        # Configure the SDK with the key loaded at module level.
        genai.configure(api_key=GOOGLE_API_KEY)
        vision_model = genai.GenerativeModel(
            model_name=VISION_MODEL_NAME,
            generation_config=GENERATION_CONFIG,
            safety_settings=SAFETY_SETTINGS,
        )
        logger.info(f"Successfully initialized Gemini Model: {VISION_MODEL_NAME}")
        return vision_model
    except Exception as e:
        err_msg = f"β Error initializing Gemini Model ({VISION_MODEL_NAME}): {e}"
        # Page config has already run, so Streamlit calls are safe here.
        st.error(err_msg)
        logger.error(err_msg, exc_info=True)
        return None
|
| 95 |
-
|
| 96 |
-
@st.cache_resource
def initialize_embedding_function() -> Optional[embedding_functions.HuggingFaceEmbeddingFunction]:
    """Build and cache the Hugging Face embedding function.

    Returns:
        The embedding function on success, otherwise ``None`` (errors are
        surfaced in the UI and logged).
    """
    st.info(f"Initializing Embedding Model: {EMBEDDING_MODEL_NAME} (this may take a moment)...")
    try:
        # HF_TOKEN is forwarded for private/gated models; harmless otherwise.
        hf_embedder = embedding_functions.HuggingFaceEmbeddingFunction(
            api_key=HF_TOKEN,
            model_name=EMBEDDING_MODEL_NAME,
        )
    except Exception as e:
        err_msg = f"β Error initializing HuggingFace Embedding Function ({EMBEDDING_MODEL_NAME}): {e}"
        st.error(err_msg)  # Safe here
        logger.error(err_msg, exc_info=True)
        st.info("βΉοΈ Make sure the embedding model name is correct and you have network access. "
                "If using a private model, ensure HF_TOKEN is set in secrets. Check Space logs for details.")
        return None
    logger.info(f"Successfully initialized HuggingFace Embedding Function: {EMBEDDING_MODEL_NAME}")
    st.success(f"Embedding Model {EMBEDDING_MODEL_NAME} initialized.")
    return hf_embedder
|
| 116 |
-
|
| 117 |
-
@st.cache_resource
def initialize_chroma_collection(_embedding_func: embedding_functions.EmbeddingFunction) -> Optional[chromadb.Collection]:
    """Open (or create) the persistent Chroma collection.

    Args:
        _embedding_func: Embedding function the collection embeds with; the
            leading underscore keeps Streamlit from trying to hash it.

    Returns:
        The collection, or ``None`` when the embedding function is missing or
        the client fails to initialize.
    """
    # Guard clause: a collection without an embedding function is useless.
    if not _embedding_func:
        st.error("β Cannot initialize Chroma DB without a valid embedding function.")  # Safe here
        return None

    st.info(f"Initializing Chroma DB collection '{COLLECTION_NAME}'...")
    try:
        client = chromadb.PersistentClient(path=CHROMA_PATH)
        chroma_collection = client.get_or_create_collection(
            name=COLLECTION_NAME,
            embedding_function=_embedding_func,
            # Distance metric is fixed at creation time via HNSW metadata.
            metadata={"hnsw:space": CHROMA_DISTANCE_METRIC},
        )
    except Exception as e:
        err_msg = f"β Error initializing Chroma DB at '{CHROMA_PATH}': {e}"
        st.error(err_msg)  # Safe here
        logger.error(err_msg, exc_info=True)
        st.info(f"βΉοΈ Ensure the path '{CHROMA_PATH}' is writable. Check Space logs.")
        return None
    logger.info(f"Chroma DB collection '{COLLECTION_NAME}' loaded/created at '{CHROMA_PATH}' using {CHROMA_DISTANCE_METRIC}.")
    st.success(f"Chroma DB collection '{COLLECTION_NAME}' ready.")
    return chroma_collection
|
| 140 |
-
|
| 141 |
-
# --- Core Logic Functions (with Caching for Data Operations) ---
|
| 142 |
-
|
| 143 |
-
@st.cache_data(show_spinner=False)  # Show spinner manually in UI
def analyze_image_with_gemini(_gemini_model: genai.GenerativeModel, image_bytes: bytes) -> Tuple[str, bool]:
    """Run Gemini vision analysis on raw image bytes.

    Args:
        _gemini_model: Initialized Gemini model (underscore skips hashing).
        image_bytes: Raw bytes of the uploaded image; used as the cache key.

    Returns:
        ``(analysis_text, is_error)`` — on any failure or safety block the
        first element is a human-readable message and the second is ``True``.
    """
    if not _gemini_model:
        return "Error: Gemini model not initialized.", True

    try:
        pil_img = Image.open(io.BytesIO(image_bytes))
        reply = _gemini_model.generate_content([GEMINI_ANALYSIS_PROMPT, pil_img])

        if reply.parts:
            logger.info("Gemini analysis successful.")
            return reply.text, False  # Indicate success

        # Empty response: distinguish a safety block from a genuinely
        # empty/invalid reply.
        feedback = reply.prompt_feedback
        if feedback and feedback.block_reason:
            msg = f"Analysis blocked by safety settings: {feedback.block_reason}"
            logger.warning(msg)
            return msg, True  # Indicate block/error state
        msg = "Error: Gemini analysis returned no content (empty or invalid response)."
        logger.error(msg)
        return msg, True
    except genai.types.BlockedPromptException as e:
        msg = f"Analysis blocked (prompt issue): {e}"
        logger.warning(msg)
        return msg, True
    except Exception as e:
        msg = f"Error during Gemini analysis: {e}"
        logger.error(msg, exc_info=True)
        return msg, True
|
| 177 |
-
|
| 178 |
-
@st.cache_data(show_spinner=False)
def query_chroma(_collection: chromadb.Collection, query_text: str, n_results: int = 5) -> Optional[Dict[str, List[Any]]]:
    """Run a similarity query against the Chroma collection.

    Args:
        _collection: The collection to search (underscore skips hashing).
        query_text: Free text to embed and search with.
        n_results: Maximum number of matches to return.

    Returns:
        The Chroma results dict (documents/metadatas/distances) or ``None``
        when the collection is unavailable, the query is empty, or the query
        raises.
    """
    if not _collection:
        logger.error("Query attempt failed: Chroma collection is not available.")
        return None
    if not query_text:
        logger.warning("Attempted to query Chroma with empty text.")
        return None

    try:
        # The analysis text is used verbatim; no query refinement yet.
        refined_query = query_text
        hits = _collection.query(
            query_texts=[refined_query],
            n_results=n_results,
            include=['documents', 'metadatas', 'distances'],
        )
    except Exception as e:
        # Surface the failure in the UI as well as the log.
        st.error(f"β Error querying Chroma DB: {e}", icon="π¨")
        logger.error(f"Error querying Chroma DB: {e}", exc_info=True)
        return None
    logger.info(f"Chroma query successful for text snippet: '{query_text[:50]}...'")
    return hits
|
| 202 |
-
|
| 203 |
-
def add_dummy_data_to_chroma(collection: chromadb.Collection, embedding_func: embedding_functions.EmbeddingFunction):
    """Adds example medical text snippets to Chroma using the provided embedding function.

    Skips the insert when the first sample document is already present.
    Errors during the add are reported through the ``st.status`` widget and
    the logger; nothing is returned.
    """
    if not collection or not embedding_func:
        st.error("β Cannot add dummy data: Chroma Collection or Embedding Function not available.")
        return

    # Check if dummy data needs adding first to avoid unnecessary processing
    docs_to_check = [
        "Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive."
    ]  # Only check one doc for speed
    try:
        # BUG FIX: Chroma's `where=` filters on *metadata* fields, so the
        # original where={"document": ...} could never match and the duplicate
        # check always failed. Matching on document content requires
        # `where_document` with the $contains operator.
        existing_check = collection.get(
            where_document={"$contains": docs_to_check[0]}, limit=1, include=[]
        )
        if existing_check and existing_check.get('ids'):
            st.info("Dummy data seems to already exist. Skipping add.")
            logger.info("Skipping dummy data addition as it likely exists.")
            return
    except Exception as e:
        # Best-effort check: fall through and attempt the add anyway.
        logger.warning(f"Could not efficiently check for existing dummy data: {e}. Proceeding with add attempt.")

    status = st.status(f"Adding dummy data (using {EMBEDDING_MODEL_NAME})...", expanded=True)
    try:
        # --- Dummy Data Definition ---
        docs = [
            "Figure 1A shows adenocarcinoma of the lung, papillary subtype. Note the glandular structures and nuclear atypia. TTF-1 staining was positive.",
            "Pathology slide 34B demonstrates high-grade glioma (glioblastoma) with significant necrosis and microvascular proliferation. Ki-67 index was high.",
            "This diagram illustrates the EGFR signaling pathway and common mutation sites targeted by tyrosine kinase inhibitors in non-small cell lung cancer.",
            "Micrograph showing chronic gastritis with Helicobacter pylori organisms (visible with special stain, not shown here). Mild intestinal metaplasia is present.",
            "Slide CJD-Sample-02: Spongiform changes characteristic of prion disease are evident in the cerebral cortex. Gliosis is also noted."
        ]
        metadatas = [
            {"source": "Example Paper 1", "topic": "Lung Cancer Pathology", "entities": "adenocarcinoma, lung cancer, glandular structures, nuclear atypia, papillary subtype, TTF-1", "IMAGE_ID": "fig_1a_adeno_lung.png"},
            {"source": "Path Report 789", "topic": "Brain Tumor Pathology", "entities": "high-grade glioma, glioblastoma, necrosis, microvascular proliferation, Ki-67", "IMAGE_ID": "slide_34b_gbm.tiff"},
            {"source": "Textbook Chapter 5", "topic": "Molecular Oncology Pathways", "entities": "EGFR, tyrosine kinase inhibitors, non-small cell lung cancer", "IMAGE_ID": "diagram_egfr_pathway.svg"},
            {"source": "Path Report 101", "topic": "Gastrointestinal Pathology", "entities": "chronic gastritis, Helicobacter pylori, intestinal metaplasia", "IMAGE_ID": "micrograph_h_pylori_gastritis.jpg"},
            {"source": "Case Study CJD", "topic": "Neuropathology", "entities": "prion disease, Spongiform changes, Gliosis, cerebral cortex", "IMAGE_ID": "slide_cjd_sample_02.jpg"}
        ]
        # Ensure IDs are unique even if run close together
        base_id = f"doc_biobert_{int(time.time() * 1000)}"
        ids = [f"{base_id}_{i}" for i in range(len(docs))]

        status.update(label=f"Generating embeddings & adding {len(docs)} documents (this uses BioBERT and may take time)...")

        # Embeddings are generated implicitly by ChromaDB during .add()
        collection.add(
            documents=docs,
            metadatas=metadatas,
            ids=ids
        )
        status.update(label=f"β Added {len(docs)} dummy documents.", state="complete", expanded=False)
        logger.info(f"Added {len(docs)} dummy documents to collection '{COLLECTION_NAME}'.")

    except Exception as e:
        err_msg = f"Error adding dummy data to Chroma: {e}"
        status.update(label=f"β Error: {err_msg}", state="error", expanded=True)
        logger.error(err_msg, exc_info=True)
|
| 259 |
-
|
| 260 |
-
# --- Initialize Resources ---
|
| 261 |
-
# These calls use @st.cache_resource, run only once unless cleared/changed.
|
| 262 |
-
# Order matters if one depends on another (embedding func needed for chroma).
|
| 263 |
-
gemini_model = initialize_gemini_model()
|
| 264 |
-
embedding_func = initialize_embedding_function()
|
| 265 |
-
collection = initialize_chroma_collection(embedding_func) # Pass embedding func
|
| 266 |
-
|
| 267 |
-
# --- Streamlit UI ---
|
| 268 |
-
# set_page_config() is already called at the top
|
| 269 |
-
|
| 270 |
-
st.title("βοΈ Medical Image Analysis & RAG (BioBERT Embeddings)")
|
| 271 |
-
|
| 272 |
-
# --- DISCLAIMER ---
|
| 273 |
-
st.warning("""
|
| 274 |
-
**β οΈ Disclaimer:** This tool is for demonstration and informational purposes ONLY.
|
| 275 |
-
It is **NOT** a medical device and should **NOT** be used for actual medical diagnosis, treatment, or decision-making.
|
| 276 |
-
AI analysis can be imperfect. Always consult with qualified healthcare professionals for any medical concerns.
|
| 277 |
-
Do **NOT** upload identifiable patient data (PHI). Analysis quality depends heavily on the chosen embedding model.
|
| 278 |
-
""", icon="β£οΈ")
|
| 279 |
-
|
| 280 |
-
st.markdown(f"""
|
| 281 |
-
Upload a medical image. Gemini Vision will analyze it. Related information
|
| 282 |
-
will be retrieved from a Chroma DB knowledge base using **{EMBEDDING_MODEL_NAME}** embeddings.
|
| 283 |
-
""")
|
| 284 |
-
|
| 285 |
-
# Sidebar
|
| 286 |
-
with st.sidebar:
|
| 287 |
-
st.header("βοΈ Controls")
|
| 288 |
-
uploaded_file = st.file_uploader(
|
| 289 |
-
"Choose an image...",
|
| 290 |
-
type=["jpg", "jpeg", "png", "tiff", "webp"],
|
| 291 |
-
help="Upload a medical image file (e.g., pathology, diagram)."
|
| 292 |
)
|
|
|
|
| 293 |
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
if collection and embedding_func:
|
| 298 |
-
add_dummy_data_to_chroma(collection, embedding_func)
|
| 299 |
-
else:
|
| 300 |
-
st.error("β Cannot add dummy data: Chroma Collection or Embedding Function failed to initialize.")
|
| 301 |
-
|
| 302 |
-
st.divider()
|
| 303 |
-
|
| 304 |
-
st.header("βΉοΈ System Info")
|
| 305 |
-
st.caption(f"**Gemini Model:** `{VISION_MODEL_NAME}`")
|
| 306 |
-
st.caption(f"**Embedding Model:** `{EMBEDDING_MODEL_NAME}`")
|
| 307 |
-
st.caption(f"**Chroma Collection:** `{COLLECTION_NAME}`")
|
| 308 |
-
st.caption(f"**Chroma Path:** `{CHROMA_PATH}`")
|
| 309 |
-
st.caption(f"**Distance Metric:** `{CHROMA_DISTANCE_METRIC}`")
|
| 310 |
-
st.caption(f"**Google API Key:** {'Set' if GOOGLE_API_KEY else 'Not Set'}")
|
| 311 |
-
st.caption(f"**HF Token:** {'Provided' if HF_TOKEN else 'Not Provided'}")
|
| 312 |
-
|
| 313 |
-
# Main Display Area
|
| 314 |
-
col1, col2 = st.columns(2)
|
| 315 |
-
|
| 316 |
-
with col1:
|
| 317 |
-
st.subheader("πΌοΈ Uploaded Image")
|
| 318 |
-
if uploaded_file is not None:
|
| 319 |
-
image_bytes = uploaded_file.getvalue()
|
| 320 |
-
st.image(image_bytes, caption=f"Uploaded: {uploaded_file.name}", use_column_width=True)
|
| 321 |
else:
|
| 322 |
-
st.
|
| 323 |
-
|
| 324 |
-
with
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
st.
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
st.info("Analysis results will appear here once an image is uploaded.")
|
| 380 |
-
else:
|
| 381 |
-
# Initialization error occurred earlier, resources might be None
|
| 382 |
-
st.error("β Analysis cannot proceed. Check if Gemini model or Chroma DB failed to initialize (see sidebar info & Space logs).")
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
st.markdown("---")
|
| 386 |
-
st.markdown("<div style='text-align: center; font-size: small;'>Powered by Google Gemini, Chroma DB, Hugging Face, and Streamlit</div>", unsafe_allow_html=True)
|
| 387 |
-
|
| 388 |
-
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
import os
|
| 3 |
import streamlit as st
|
| 4 |
+
from streamlit_drawable_canvas import st_canvas
|
|
|
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
+
import openai
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
# βββ 1. Configuration & Secrets βββββββββββββββββββββββββββββββββββββββββββββ
# Key comes from Streamlit secrets; the whole app depends on it being set.
# NOTE(review): st.secrets[...] raises KeyError when the secret is missing —
# presumably intended fail-fast behavior; confirm no friendlier error wanted.
openai.api_key = st.secrets["OPENAI_API_KEY"]  # or os.getenv("OPENAI_API_KEY")
st.set_page_config(
    page_title="MedSketchβ―AI",
    layout="wide",
    initial_sidebar_state="expanded",
)
|
| 17 |
+
|
| 18 |
+
# βββ 2. Sidebar: Settings & Metadata ββββββββββββββββββββββββββββββββββββββββ
st.sidebar.header("βοΈ Settings")
# Rendering backend: hosted GPT-4o API call or the local SD LoRA stub.
model_choice = st.sidebar.selectbox(
    "Model",
    ["GPT-4o (API)", "Stable Diffusion LoRA"],
    index=0
)
# Style label injected into the prompt prefix at generation time.
style_preset = st.sidebar.radio(
    "Preset Style",
    ["Anatomical Diagram", "H&E Histology", "IHC Pathology", "Custom"]
)
# 0.1–1.0, default 0.7; also forwarded inside the prompt prefix.
strength = st.sidebar.slider("Stylization Strength", 0.1, 1.0, 0.7)

st.sidebar.markdown("---")
st.sidebar.header("π Metadata")
# Free-text case metadata; collected but not attached to any output in the
# code visible here — TODO confirm intended use.
patient_id = st.sidebar.text_input("Patient / Case ID")
roi = st.sidebar.text_input("Region of Interest")
umls_code = st.sidebar.text_input("UMLS / SNOMED CT Code")
|
| 36 |
+
|
| 37 |
+
# βββ 3. Main: Prompt Input & Batch Generation βββββββββββββββββββββββββββββββ
st.title("πΌοΈ MedSketchβ―AI β Advanced Clinical Diagram Generator")

with st.expander("π Enter Prompts (one per line for batch)"):
    raw = st.text_area(
        "Describe what you need:",
        placeholder=(
            "e.g. βGenerate a labeled crossβsection of the human heart with chamber names, valves, and flow arrowsβ¦β\n"
            "e.g. βProduce a stylized H&E stain of liver tissue highlighting portal triadsβ¦β"
        ),
        height=120
    )
# One prompt per non-blank line; surrounding whitespace is stripped.
prompts = [p.strip() for p in raw.splitlines() if p.strip()]
|
| 50 |
|
| 51 |
+
# βββ Generation, Annotation & Export βββββββββββββββββββββββββββββββββββββββ
# BUG FIX: the original called requests.get() but `requests` is never
# imported in this file; use the standard library instead.
from urllib.request import urlopen

if st.button("π Generate"):
    if not prompts:
        st.error("Please enter at least one prompt.")
    else:
        # At most 3 columns; with more prompts, results wrap back around.
        cols = st.columns(min(3, len(prompts)))
        for i, prompt in enumerate(prompts):
            with st.spinner(f"Rendering image {i+1}/{len(prompts)}β¦"):
                if model_choice == "GPT-4o (API)":
                    # NOTE(review): openai.Image.create is the legacy (<1.0)
                    # SDK surface, and "gpt-4o" is not an image-generation
                    # model id — confirm against the installed openai version.
                    resp = openai.Image.create(
                        model="gpt-4o",
                        prompt=f"[{style_preset} | strength={strength}] {prompt}",
                        size="1024x1024"
                    )
                    with urlopen(resp["data"][0]["url"]) as r:
                        img_data = r.read()
                else:
                    # Stub for Stable Diffusion LoRA; generate_sd_image is not
                    # defined in this file — assumed provided elsewhere.
                    img_data = generate_sd_image(prompt, style=style_preset, strength=strength)
            img = Image.open(BytesIO(img_data))

            # Display + Download.
            # BUG FIX: `cols` has at most 3 entries, but i exceeds 2 for
            # batches larger than 3 — wrap around instead of IndexError.
            with cols[i % len(cols)]:
                st.image(img, use_column_width=True, caption=prompt)
                buf = BytesIO()
                img.save(buf, format="PNG")
                st.download_button(
                    label="β¬οΈ Download PNG",
                    data=buf.getvalue(),
                    file_name=f"medsketch_{i+1}.png",
                    mime="image/png"
                )

                # βββ Annotation Canvas βββββββββββββββββββββββββββ
                st.markdown("**βοΈ Annotate:**")
                canvas_res = st_canvas(
                    fill_color="rgba(255, 0, 0, 0.3)",  # annotation color
                    stroke_width=2,
                    background_image=img,
                    update_streamlit=True,
                    height=512,
                    width=512,
                    drawing_mode="freedraw",
                    key=f"canvas_{i}"
                )
                # Persist the drawn objects keyed by their prompt text.
                if canvas_res.json_data:
                    ann = canvas_res.json_data["objects"]
                    st.session_state.setdefault("annotations", {})[prompt] = ann

# βββ 4. History & Exports βββββββββββββββββββββββββββββββββββββββββββββββββββ
if "annotations" in st.session_state:
    st.markdown("---")
    st.subheader("π Session History & Annotations")
    for prm, objs in st.session_state["annotations"].items():
        st.markdown(f"**Prompt:** {prm}")
        st.json(objs)
    st.download_button(
        "β¬οΈ Export All Annotations (JSON)",
        data=json.dumps(st.session_state["annotations"], indent=2),
        file_name="medsketch_annotations.json",
        mime="application/json"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|