Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import subprocess | |
| import sys | |
| import importlib | |
| from medigan import Generators | |
def install_dependencies(model_id):
    """Install the dependencies of a single medigan model.

    Runs the model dependency installer in a child process and reports the
    outcome in the Streamlit UI.

    Args:
        model_id: Identifier of the model whose dependencies are installed.

    Returns:
        True if the installer exited with status 0, False otherwise.
    """
    # An argument list (no shell) means model_id is never parsed as shell
    # syntax, and sys.executable targets the interpreter the app runs under
    # rather than whatever "python" happens to be first on PATH.
    # NOTE(review): the module path assumes a "src/medigan" layout relative to
    # the working directory - confirm against the deployment.
    command = [
        sys.executable,
        "-m",
        "src.medigan.install_model_dependencies",
        "--model_id",
        str(model_id),
    ]
    try:
        subprocess.run(command, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        # stderr is captured as bytes; decode leniently so an odd byte in the
        # installer output cannot mask the real failure.
        details = e.stderr.decode(errors="replace") if e.stderr else ""
        st.error(f"Failed to install dependencies: {details}")
        return False
    st.success(f"Dependencies installed for {model_id}")
    return True
# pip distribution names whose importable module has a different name.
_IMPORT_NAMES = {
    "pyyaml": "yaml",
    "opencv-contrib-python-headless": "cv2",
    "pillow": "PIL",
    "scikit-image": "skimage",
}


def check_and_install(package):
    """Import a package's module, installing the package with pip if missing.

    Args:
        package: A pip requirement string. Extras and version specifiers
            (e.g. ``scikit-image[optional]`` or ``numpy>=1.20``) are stripped
            before the import check but passed unchanged to pip.

    Raises:
        subprocess.CalledProcessError: If the pip installation fails.
    """
    # Reduce the requirement to its bare distribution name.
    distribution = package
    for separator in ("[", "=", "<", ">", "!", "~", ";", " "):
        distribution = distribution.split(separator, 1)[0]

    # The module name often differs from the distribution name (Pillow -> PIL).
    # Without this mapping the import check fails for such packages on every
    # run and pip is invoked again each time.
    module_name = _IMPORT_NAMES.get(
        distribution.lower(), distribution.replace("-", "_")
    )

    try:
        importlib.import_module(module_name)
    except ImportError:
        st.warning(f"Installing missing dependency: {package}")
        subprocess.run(
            [sys.executable, "-m", "pip", "install", package], check=True
        )
        # Let the running interpreter notice the freshly installed files.
        importlib.invalidate_caches()
        st.success(f"Successfully installed {package}")
# Third-party packages needed by all models, listed by pip distribution name.
# pathlib is intentionally absent: it is part of the Python 3 standard
# library, and the PyPI distribution of that name is an obsolete backport
# that should not be installed over it.
COMMON_DEPENDENCIES = [
    "numpy",
    "pyyaml",
    "opencv-contrib-python-headless",
    "torch",
    "torchvision",
    "dominate",
    "visdom",
    "Pillow",
    "imageio",
    "scikit-image",
]
def main():
    """Render the Streamlit page and run generation when requested.

    Configures the page, checks every entry of COMMON_DEPENDENCIES, shows the
    model and image-count controls in the sidebar, and on a button press
    installs the selected model's dependencies before generating images.
    """
    st.set_page_config(page_title="MEDIGAN Generator", layout="wide")
    st.title("🧠 Medical Image Generator")

    # Make sure the shared dependencies are importable before anything else.
    with st.spinner("Checking system dependencies..."):
        for requirement in COMMON_DEPENDENCIES:
            check_and_install(requirement)

    MODEL_IDS = [
        "00001_DCGAN_MMG_CALC_ROI",
        "00002_DCGAN_MMG_MASS_ROI",
        "00003_CYCLEGAN_MMG_DENSITY_FULL",
        "00004_PIX2PIX_MMG_MASSES_W_MASKS",
        "00019_PGGAN_CHEST_XRAY",
    ]

    with st.sidebar:
        st.header("⚙️ Settings")
        model_id = st.selectbox("Select Model", MODEL_IDS)
        num_images = st.slider("Number of Images", 1, 8, 4)

    # Nothing more to do until the user asks for images.
    if not st.button("✨ Generate Images"):
        return

    with st.spinner("Initializing model..."):
        # Generation only proceeds when the model's dependencies installed.
        if install_dependencies(model_id):
            generate_images(num_images, model_id)
def generate_images(num_images, model_id):
    """Generate images with a medigan model and show them in a 4-column grid.

    Any exception raised while generating or rendering is reported in the UI
    together with a manual installation hint instead of propagating.

    Args:
        num_images: Number of images to generate, one sample per call.
        model_id: Identifier of the medigan model to sample from.
    """
    try:
        # Imported lazily so a missing torchvision is reported through the
        # handler below rather than failing at module import time.
        from torchvision.transforms.functional import to_pil_image

        generators = Generators()
        images = []
        for _ in range(num_images):
            sample = generators.generate(
                model_id=model_id,
                num_samples=1,
                install_dependencies=False,  # Already handled by install_dependencies()
            )
            # NOTE(review): assumes sample[0] is an array/tensor accepted by
            # to_pil_image - confirm for models that return image/mask pairs.
            img = to_pil_image(sample[0]).convert("RGB")
            images.append(img)

        cols = st.columns(4)
        for idx, img in enumerate(images):
            # Wrap around so more than four images fill further rows.
            with cols[idx % 4]:
                st.image(img, caption=f"Image {idx+1}", use_column_width=True)
        st.markdown("---")
    except Exception as e:
        st.error(f"Generation failed: {str(e)}")
        st.info("If the error persists, try manual installation:")
        st.code(f"pip install {' '.join(COMMON_DEPENDENCIES)}")
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()