Spaces:
Sleeping
Sleeping
| # app.py | |
| """ | |
| MedSketch AI: Advanced Clinical Diagram Generator | |
| A Streamlit application leveraging AI models (GPT-4o, potentially Stable Diffusion) | |
| to generate medical diagrams based on user prompts, with options for styling, | |
| metadata association, and annotations. | |
| """ | |
| import os | |
| import json | |
| import logging | |
| from io import BytesIO | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import streamlit as st | |
| from streamlit_drawable_canvas import st_canvas | |
| from PIL import Image | |
| import openai | |
| from openai import OpenAI, OpenAIError # Use modern OpenAI client and error types | |
# ─── Constants ────────────────────────────────────────────────────────────────
# Page title shown in the browser tab, the "About" menu entry and the page heading.
APP_TITLE: str = "MedSketch AI β Advanced Clinical Diagram Generator"

# Labels offered in the "Select Model" dropdown. They are display names only.
# NOTE(review): the default label says GPT-4o, but generate_openai_image
# requests model="dall-e-3" — confirm which name should be shown to users.
DEFAULT_MODEL: str = "GPT-4o (Vision)"
STABLE_DIFFUSION_MODEL: str = "Stable Diffusion LoRA"  # placeholder backend
MODEL_OPTIONS: List[str] = [DEFAULT_MODEL, STABLE_DIFFUSION_MODEL]

# Style presets; "Custom" makes the sidebar show a free-text style field.
STYLE_PRESETS: List[str] = [
    "Anatomical Diagram",
    "H&E Histology",
    "IHC Pathology",
    "Custom",
]
DEFAULT_STYLE: str = "Anatomical Diagram"
DEFAULT_STRENGTH: float = 0.7

# Size string passed to the image API, and the bounding box (in pixels) used
# for the annotation canvas and the placeholder image.
IMAGE_SIZE: str = "1024x1024"
CANVAS_SIZE: int = 512

# Drawing defaults for the annotation canvas (semi-transparent red fill).
ANNOTATION_COLOR: str = "rgba(255, 0, 0, 0.3)"
ANNOTATION_STROKE_WIDTH: int = 2

# Keys under which per-session data is kept in st.session_state.
SESSION_STATE_ANNOTATIONS: str = "medsketch_annotations"
SESSION_STATE_HISTORY: str = "medsketch_history"
# ─── Setup & Configuration ────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Must run before any other Streamlit command that renders output.
st.set_page_config(
    page_title=APP_TITLE,
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'About': f"{APP_TITLE} - Generates medical diagrams using AI.",
        'Get Help': None,  # Add a link if you have one
        'Report a bug': None,  # Add a link if you have one
    },
)


def _resolve_openai_api_key() -> Optional[str]:
    """Return the OpenAI API key from Streamlit secrets, else the environment.

    Streamlit secrets take precedence (deployment); the ``OPENAI_API_KEY``
    environment variable is the fallback (local development).

    Returns:
        The key string, or ``None`` when neither source provides one.
    """
    env_key = os.getenv("OPENAI_API_KEY")
    try:
        return st.secrets.get("OPENAI_API_KEY", env_key)
    except FileNotFoundError:
        # Accessing st.secrets without any secrets.toml raises instead of
        # behaving like an empty mapping, which previously made the
        # environment-variable fallback unreachable on a developer machine.
        # NOTE(review): newer Streamlit releases raise
        # StreamlitSecretNotFoundError, which is expected to subclass
        # FileNotFoundError — confirm for the pinned Streamlit version.
        logger.info("No Streamlit secrets file found; using environment variable.")
        return env_key


# Initialize OpenAI Client
api_key = _resolve_openai_api_key()
if not api_key:
    st.error("π¨ OpenAI API Key not found! Please set it in Streamlit secrets or environment variables.", icon="π¨")
    st.stop()  # Halt execution if no key

try:
    client = OpenAI(api_key=api_key)
    logger.info("OpenAI client initialized successfully.")
except Exception as e:  # top-level boundary: report, log, and halt the app
    st.error(f"π¨ Failed to initialize OpenAI client: {e}", icon="π¨")
    logger.exception("OpenAI client initialization failed.")
    st.stop()
# ─── Helper Functions ─────────────────────────────────────────────────────────
def generate_openai_image(prompt: str, style: str, strength: float) -> Image.Image:
    """
    Generates an image through the OpenAI images endpoint and returns it.

    The style and strength are not API parameters; they are only embedded in
    the text prompt sent to the model. The API is asked for a URL, which is
    then downloaded and opened with PIL.

    Args:
        prompt: The user's text prompt.
        style: The selected style preset (embedded in the prompt text).
        strength: The stylization strength (embedded in the prompt text).
    Returns:
        A PIL Image object.
    Raises:
        OpenAIError: If the API call fails.
        IOError: If the generated image cannot be downloaded.
        Exception: Any other error is logged, shown in the UI and re-raised.
    """
    # Imported before the ``try`` on purpose: ``requests`` is a local name, and
    # the ``except requests.exceptions.RequestException`` clause below is
    # evaluated for any non-OpenAI error. With the import inside the ``try``,
    # an early failure raised UnboundLocalError and hid the real exception.
    import requests

    logger.info("Requesting OpenAI image generation for prompt: '%s' with style '%s'", prompt, style)
    full_prompt = f"Style: [{style}], Strength: [{strength:.2f}] - Generate the following medical illustration: {prompt}"
    try:
        response = client.images.generate(
            model="dall-e-3",  # Or "gpt-4o" if/when available via this endpoint.
            prompt=full_prompt,
            size=IMAGE_SIZE,
            quality="standard",  # or "hd"
            n=1,
            response_format="url",  # "b64_json" would avoid the second request
        )
        image_url = response.data[0].url
        logger.info("Image generated successfully, URL: %s", image_url)

        # Second request: fetch the actual image bytes from the returned URL.
        image_response = requests.get(image_url, timeout=30)
        image_response.raise_for_status()  # surface HTTP errors as exceptions
        img_data = BytesIO(image_response.content)
        return Image.open(img_data)
    except OpenAIError as e:
        logger.error("OpenAI API error: %s", e)
        st.error(f"β OpenAI API Error: {e}", icon="β")
        raise
    except requests.exceptions.RequestException as e:
        # Only reachable after image_url was assigned (raised by requests).
        logger.error("Failed to download image from URL %s: %s", image_url, e)
        st.error(f"β Network Error: Failed to download image. {e}", icon="β")
        raise IOError(f"Failed to download image: {e}") from e
    except Exception as e:
        logger.exception("An unexpected error occurred during OpenAI image generation: %s", e)
        st.error(f"β An unexpected error occurred: {e}", icon="β")
        raise
def generate_sd_image(prompt: str, style: str, strength: float) -> Image.Image:
    """
    Stand-in for a Stable Diffusion LoRA backend; no model is called.

    Logs and shows a "not implemented" warning, then returns a light-grey
    CANVAS_SIZE x CANVAS_SIZE image with the style and the first 50 characters
    of the prompt drawn onto it. Sleeps for one second to imitate work.

    Args:
        prompt: The user's text prompt (truncated in the placeholder text).
        style: The selected style preset (printed on the placeholder).
        strength: The stylization strength (unused by the placeholder).
    Returns:
        A PIL Image object containing the placeholder.
    """
    import time
    from PIL import ImageDraw

    logger.warning("Stable Diffusion LoRA model generation is not implemented. Returning placeholder.")
    st.warning("π§ Stable Diffusion LoRA generation is not yet implemented. Using placeholder.", icon="π§")

    # --- Placeholder: swap this section for the real SD model call ---
    caption_text = f"Stable Diffusion Placeholder\nStyle: {style}\nPrompt: {prompt[:50]}..."
    grey = (210, 210, 210)
    black = (0, 0, 0)
    placeholder = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), color=grey)
    pen = ImageDraw.Draw(placeholder)
    pen.text((10, 10), caption_text, fill=black)
    # --- End placeholder ---

    # Simulated processing delay.
    time.sleep(1)
    return placeholder
def _filename_fragment(text: str, max_len: int = 20) -> str:
    """Return a file-name-safe fragment built from the start of *text*."""
    import re

    # Keep word characters and hyphens; collapse everything else (spaces,
    # slashes, quotes, colons, ...) into single underscores.
    fragment = re.sub(r"[^\w-]+", "_", text[:max_len]).strip("_")
    return fragment or "image"


def display_result(image: Image.Image, prompt: str, index: int, total: int) -> Optional[List[Dict[str, Any]]]:
    """
    Displays a generated image, a PNG download button, and an annotation canvas.

    Args:
        image: The PIL Image to display.
        prompt: The prompt used to generate the image (caption and file name).
        index: Zero-based position of this image; also used to build the
            widget keys, so it must be unique among results shown together.
        total: The total number of images being shown.
    Returns:
        The list of drawn canvas objects if any exist, otherwise None.
    """
    st.image(image, caption=f"Result {index + 1}/{total}: {prompt}", use_container_width=True)

    # Serialize the image once for the download button.
    buf = BytesIO()
    image.save(buf, format="PNG")
    st.download_button(
        label="β¬οΈ Download PNG",
        data=buf.getvalue(),
        # The prompt is free text, so it is sanitized before it is used in a
        # file name (previously only spaces were replaced).
        file_name=f"medsketch_{index+1}_{_filename_fragment(prompt)}.png",
        mime="image/png",
        key=f"download_{index}",
    )

    # Annotation canvas
    st.markdown("**βοΈ Annotate:**")
    # Work on a copy: thumbnail() resizes in place, and the caller's image must
    # stay at full resolution. Aspect ratio is preserved within CANVAS_SIZE.
    canvas_image = image.copy()
    canvas_image.thumbnail((CANVAS_SIZE, CANVAS_SIZE))
    canvas_result = st_canvas(
        fill_color=ANNOTATION_COLOR,
        stroke_width=ANNOTATION_STROKE_WIDTH,
        background_image=canvas_image,
        update_streamlit=True,  # send drawing updates back to the script
        height=canvas_image.height,
        width=canvas_image.width,
        drawing_mode="freedraw",  # other modes include "line", "rect", etc.
        key=f"canvas_{index}",
    )

    canvas_json = canvas_result.json_data
    if canvas_json and canvas_json.get("objects"):
        return canvas_json["objects"]
    return None
# ─── Initialize Session State ─────────────────────────────────────────────────
# Seed the per-session containers on the first run only; later reruns keep
# whatever has been accumulated.
#   annotations: Dict[prompt, List[annotation_objects]]
#   history:     List[Dict[str, Any]] describing each generation attempt
for _state_key, _make_empty in (
    (SESSION_STATE_ANNOTATIONS, dict),
    (SESSION_STATE_HISTORY, list),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_empty()
# ─── Sidebar: Settings & Metadata ─────────────────────────────────────────────
# Sidebar widgets. The values bound here (model_choice, final_style, strength,
# patient_id, roi, umls_code) are read by the generation step further down.
with st.sidebar:
    st.header("βοΈ Generation Settings")
    model_choice = st.selectbox(
        "Select Model",
        options=MODEL_OPTIONS,
        index=MODEL_OPTIONS.index(DEFAULT_MODEL),
        help="Choose the AI model for image generation.",
    )
    style_preset = st.radio(
        "Select Preset Style",
        options=STYLE_PRESETS,
        index=STYLE_PRESETS.index(DEFAULT_STYLE),
        horizontal=True,  # More compact layout
        help="Apply a predefined visual style to the generation.",
    )

    # The free-text style field is only shown when "Custom" is selected.
    custom_style_input = ""
    if style_preset == "Custom":
        custom_style_input = st.text_input("Enter Custom Style Description:", key="custom_style")
        final_style = custom_style_input.strip()
        if not final_style:
            # An empty custom style used to be passed through as "", which
            # produced "Style: []" in the model prompt and an empty style in
            # the history. Fall back to the default preset instead.
            final_style = DEFAULT_STYLE
            st.caption(f"No custom style entered; using '{DEFAULT_STYLE}'.")
    else:
        final_style = style_preset

    strength = st.slider(
        "Stylization Strength",
        min_value=0.1,
        max_value=1.0,
        value=DEFAULT_STRENGTH,
        step=0.05,
        help="Controls how strongly the chosen style influences the result (conceptual).",
    )

    st.markdown("---")
    st.header("π Optional Metadata")
    patient_id = st.text_input("Patient / Case ID", key="patient_id", help="Associate with a specific patient or case.")
    roi = st.text_input("Region of Interest (ROI)", key="roi", help="Specify the anatomical region shown.")
    umls_code = st.text_input("UMLS / SNOMED CT Code", key="umls_code", help="Link to relevant medical ontology codes.")

    # Clear-history button: reset both session containers and rerun so the
    # page immediately reflects the empty state.
    st.markdown("---")
    if st.button("β οΈ Clear History & Annotations", help="Removes all generated images and annotations from this session."):
        st.session_state[SESSION_STATE_ANNOTATIONS] = {}
        st.session_state[SESSION_STATE_HISTORY] = []
        st.rerun()
# ─── Main Application Area ────────────────────────────────────────────────────
st.title(APP_TITLE)
st.markdown("Generate medical illustrations from text descriptions using AI. Annotate and export your results.")

# --- Prompt Input Area ---
# Example prompts shown as greyed-out placeholder text, one per line.
_PLACEHOLDER_EXAMPLES = (
    "Example 1: A sagittal view of the human knee joint, labeling the ACL, PCL, meniscus, femur, and tibia.",
    "Example 2: High-power field H&E stain of lung adenocarcinoma showing glandular formation.",
    "Example 3: Immunohistochemistry (IHC) stain for PD-L1 in tonsil tissue, showing positive staining on immune cells.",
)

prompt_input_area = st.container()
with prompt_input_area:
    st.subheader("π Enter Prompt(s)")
    st.caption("Enter one prompt per line to generate multiple images in a batch.")
    raw_prompts = st.text_area(
        "Describe the medical diagram(s) you need:",
        placeholder="\n".join(_PLACEHOLDER_EXAMPLES),
        height=150,
        label_visibility="collapsed",
    )

# One prompt per non-blank line, with surrounding whitespace removed.
prompts: List[str] = []
for _raw_line in raw_prompts.splitlines():
    _cleaned_line = _raw_line.strip()
    if _cleaned_line:
        prompts.append(_cleaned_line)

# --- Generation Trigger ---
# Pluralize the label for batches; the button is disabled until there is at
# least one prompt.
_plural_suffix = "s" if len(prompts) > 1 else ""
generate_button = st.button(
    "π Generate Diagram" + _plural_suffix,
    type="primary",
    disabled=not prompts,
    use_container_width=True,
)
# --- Generation and Display Area ---
# Streamlit re-executes this script on every widget interaction, and
# `generate_button` is True only for the single run triggered by the click.
# Results rendered solely inside `if generate_button:` therefore vanished as
# soon as the user drew on a canvas or pressed a download button, so drawn
# annotations never reached session state. To keep results interactive, the
# generation step only *stores* each result (including its PNG bytes) in the
# session history; the history section below renders images and annotation
# canvases on every run.
# Trade-off: PNG bytes are held in session memory until the history is cleared.


def _encode_png(image: Image.Image) -> bytes:
    """Serialize *image* to PNG bytes so it can be kept in session state."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


results_area = st.container()

if generate_button:
    with results_area:
        if not prompts:
            st.warning("β οΈ Please enter at least one prompt description.", icon="β οΈ")
        else:
            num_prompts = len(prompts)
            logger.info("Starting generation for %d prompts using model '%s'.", num_prompts, model_choice)
            progress_bar = st.progress(0, text="Initializing generation...")

            for i, prompt in enumerate(prompts):
                spinner_msg = f"Generating image {i+1}/{num_prompts} for prompt: \"{prompt[:50]}...\""
                with st.spinner(spinner_msg):
                    try:
                        # Select generation function based on model choice.
                        if model_choice == DEFAULT_MODEL:
                            generated_image = generate_openai_image(prompt, final_style, strength)
                        elif model_choice == STABLE_DIFFUSION_MODEL:
                            generated_image = generate_sd_image(prompt, final_style, strength)
                        else:
                            st.error(f"Unknown model selected: {model_choice}", icon="β")
                            generated_image = None

                        if generated_image is not None:
                            st.session_state[SESSION_STATE_HISTORY].append({
                                "prompt": prompt,
                                "model": model_choice,
                                "style": final_style,
                                "strength": strength,
                                "metadata": {
                                    "patient_id": patient_id,
                                    "roi": roi,
                                    "umls_code": umls_code,
                                },
                                "image_ref_index": i,  # position within this batch
                                # Not included in the JSON export below.
                                "image_png": _encode_png(generated_image),
                            })
                    except Exception as e:
                        # Per-prompt boundary: one failure must not abort the
                        # rest of the batch. The attempt is recorded as failed.
                        logger.error("Generation failed for prompt '%s': %s", prompt, e)
                        st.error(f"Failed to generate image for prompt: '{prompt}'. Error: {e}", icon="π₯")
                        st.session_state[SESSION_STATE_HISTORY].append({
                            "prompt": prompt, "status": "failed", "error": str(e)
                        })

                # Always advance the bar, including for skipped/failed prompts.
                progress_bar.progress((i + 1) / num_prompts, text=f"Generated {i+1}/{num_prompts} images...")

            progress_bar.progress(1.0, text="Batch generation complete!")
            st.toast(f"Finished generating {num_prompts} image(s)!", icon="π")

# ─── History & Exports Section ────────────────────────────────────────────────
history_area = st.container()
with history_area:
    history = st.session_state[SESSION_STATE_HISTORY]
    # Dict[prompt, List[annotation_objects]]; mutated in place so changes
    # persist in session state.
    # NOTE(review): annotations are keyed by prompt text, so two history
    # entries with the same prompt share one annotation slot.
    saved_annotations = st.session_state[SESSION_STATE_ANNOTATIONS]

    if history:
        st.markdown("---")
        st.subheader("π Session History & Annotations")
        st.caption("Review generated images (if stored) and their annotations from this session.")

        total_items = len(history)
        for idx, item in enumerate(history):
            if item.get("status") == "failed":
                st.warning(f"**Prompt {idx+1} (Failed):** {item['prompt']} \n *Error: {item['error']}*", icon="β οΈ")
            else:
                prompt_key = item["prompt"]
                st.markdown(f"**Prompt {idx+1}:** `{prompt_key}`")
                st.markdown(f"*Model: {item['model']}, Style: {item['style']}*")

                # Display metadata if any field was filled in.
                meta = item.get('metadata', {})
                if any(meta.values()):
                    meta_str = ", ".join(f"{k}: {v}" for k, v in meta.items() if v)
                    st.markdown(f"*Metadata: {meta_str}*")

                # Re-render the stored image with its download button and
                # canvas. The history index gives each widget a stable key
                # across reruns, so the canvas keeps its drawing.
                image_png = item.get("image_png")
                if image_png:
                    image_col, _spacer_col = st.columns(2)
                    with image_col:
                        drawn = display_result(Image.open(BytesIO(image_png)), prompt_key, idx, total_items)
                    if drawn:
                        saved_annotations[prompt_key] = drawn

                annotations = saved_annotations.get(prompt_key)
                if annotations:
                    with st.expander(f"View {len(annotations)} Annotation(s)"):
                        st.json(annotations)
                else:
                    st.caption("_(No annotations made for this item yet)_")
            st.markdown("---")  # Separator between history items

        # --- Export Annotations ---
        if saved_annotations:
            st.markdown("---")
            st.subheader("β¬οΈ Export Annotations")
            try:
                # Enrich each annotation set with the details of the matching
                # successful generation (the latest one wins for duplicates).
                history_map = {
                    item['prompt']: item
                    for item in history
                    if item.get('status') != 'failed'
                }
                export_data = {}
                for prompt, ann_objs in saved_annotations.items():
                    history_item = history_map.get(prompt)
                    export_data[prompt] = {
                        "annotations": ann_objs,
                        "generation_details": {
                            "model": history_item.get('model'),
                            "style": history_item.get('style'),
                            "strength": history_item.get('strength'),
                        } if history_item else None,
                        "metadata": history_item.get('metadata') if history_item else None,
                    }
                json_data = json.dumps(export_data, indent=2)
                st.download_button(
                    label="β¬οΈ Export All Annotations (JSON)",
                    data=json_data,
                    file_name="medsketch_session_annotations.json",
                    mime="application/json",
                    help="Download all annotations made during this session, including associated metadata.",
                )
            except (TypeError, ValueError) as e:
                # json.dumps raises these for non-serializable or invalid data.
                st.error(f"Failed to prepare annotations for download: {e}")
                logger.error("Error preparing JSON export: %s", e)
    elif generate_button:  # Generate was clicked but nothing was recorded
        st.info("No successful generations or annotations in the current session yet.")

# Footer
st.markdown("---")
st.caption("MedSketch AI - Powered by Streamlit and OpenAI")