# streamlit_app.py
# ──── SET ENVIRONMENT VARIABLES BEFORE ANY IMPORTS ──────────────────────────────
import os
import tempfile

# Dedicated writable cache locations — hosted containers often have a
# read-only or ephemeral $HOME, so everything is pointed at /tmp.
CACHE_DIR = "/tmp/hf_cache"
STREAMLIT_DIR = "/tmp/streamlit"


def safe_makedirs(path):
    """Create *path* (and parents) if missing; warn instead of crashing."""
    try:
        os.makedirs(path, exist_ok=True)
    except Exception as e:
        print(f"Warning: Could not create {path}: {str(e)}")


safe_makedirs(CACHE_DIR)
safe_makedirs(STREAMLIT_DIR)

# Redirect every cache-aware library to /tmp. This MUST happen before the
# libraries below are imported, because they read these variables at import time.
os.environ.update({
    "HOME": "/tmp",
    "XDG_CONFIG_HOME": "/tmp",
    "STREAMLIT_HOME": STREAMLIT_DIR,
    "XDG_CACHE_HOME": CACHE_DIR,
    "HF_HOME": f"{CACHE_DIR}/huggingface",
    # NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
    # releases (HF_HOME covers it); kept for older versions.
    "TRANSFORMERS_CACHE": f"{CACHE_DIR}/transformers",
    "HF_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
    "HUGGINGFACE_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
    "HF_HUB_DISABLE_TELEMETRY": "1",
    "STREAMLIT_SERVER_ENABLE_FILE_WATCHER": "false"
})

# Create all cache subdirectories
for path in [
    f"{CACHE_DIR}/huggingface",
    f"{CACHE_DIR}/transformers",
    f"{CACHE_DIR}/huggingface_hub",
    f"{STREAMLIT_DIR}/config"
]:
    safe_makedirs(path)

# Write a Streamlit config that disables usage stats and the file watcher.
CONFIG_TOML = f"{STREAMLIT_DIR}/config/config.toml"
if not os.path.exists(CONFIG_TOML):
    try:
        with open(CONFIG_TOML, "w") as f:
            f.write("[browser]\n")
            f.write("gatherUsageStats = false\n")
            f.write("[server]\n")
            # BUG FIX: bare `none` is not valid TOML — string values must be
            # quoted, otherwise Streamlit fails to parse this file.
            f.write('fileWatcherType = "none"\n')
    except Exception as e:
        print(f"Warning: Could not create Streamlit config: {str(e)}")

# ──── NOW IMPORT OTHER LIBRARIES ───────────────────────────────────────────────
import json
import torch
import torch.nn as nn
import torchvision.transforms as T
import streamlit as st
from PIL import Image
from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
from huggingface_hub import hf_hub_download, HfApi

# ──── MODEL DEFINITION ─────────────────────────────────────────────────────────
MODEL_ID = "RakeshNJ12345/Chest-Radiology"
PROXY_URL = "https://hf-mirror.com"  # Proxy for Hugging Face downloads
class TwoViewVisionReportModel(nn.Module):
    """ViT image encoder + T5 decoder that turns a chest X-ray into a text report.

    Despite the name, ``generate`` currently consumes a single (frontal) view
    only; ``proj_l`` (the lateral-view projection) is kept so that saved
    checkpoints containing its weights still load.
    """

    def __init__(self, vit: ViTModel, t5: T5ForConditionalGeneration,
                 tokenizer: T5Tokenizer):
        super().__init__()
        self.vit = vit
        # Project pooled ViT features into T5's embedding space.
        self.proj_f = nn.Linear(vit.config.hidden_size, t5.config.d_model)
        self.proj_l = nn.Linear(vit.config.hidden_size, t5.config.d_model)
        self.tokenizer = tokenizer
        self.t5 = t5

    def generate(self, img: torch.Tensor, max_length: int = 128) -> torch.Tensor:
        """Greedily decode a report for one image.

        Args:
            img: pixel-value tensor — assumed shape (1, 3, 224, 224) given the
                hard-coded batch-1 mask below; TODO confirm against callers.
            max_length: maximum number of generated tokens.

        Returns:
            Token-id tensor produced by ``T5.generate``.
        """
        device = img.device
        vf = self.vit(pixel_values=img).pooler_output
        pf = self.proj_f(vf).unsqueeze(1)
        prefix = pf  # single-view only
        enc = self.tokenizer("report:", return_tensors="pt").to(device)
        txt_emb = self.t5.encoder.embed_tokens(enc.input_ids)
        # Prepend the projected image feature as an extra "token" before the
        # text prompt, and extend the attention mask accordingly.
        enc_emb = torch.cat([prefix, txt_emb], dim=1)
        enc_mask = torch.cat([
            torch.ones(1, 1, device=device, dtype=torch.long),
            enc.attention_mask
        ], dim=1)
        enc_out = self.t5.encoder(
            inputs_embeds=enc_emb,
            attention_mask=enc_mask
        )
        out_ids = self.t5.generate(
            encoder_outputs=enc_out,
            # NOTE(review): recent transformers expect `attention_mask=` for the
            # encoder mask when `encoder_outputs` is precomputed — confirm the
            # pinned transformers version actually accepts `encoder_attention_mask`.
            encoder_attention_mask=enc_mask,
            max_length=max_length,
            num_beams=1,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        return out_ids


# ──── MODEL LOADING WITH ERROR HANDLING ────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def load_models():
    """Download all weights and assemble the report model.

    Cached by Streamlit so the downloads happen once per server process.

    Returns:
        Tuple of (device, model, tokenizer).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Ensure cache directories exist (cheap and idempotent).
    for path in [
        f"{CACHE_DIR}/huggingface",
        f"{CACHE_DIR}/transformers",
        f"{CACHE_DIR}/huggingface_hub"
    ]:
        safe_makedirs(path)

    try:
        # Download the model config from the Hub.
        cfg_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="config.json",
            repo_type="model",
            cache_dir=f"{CACHE_DIR}/huggingface_hub",
            local_files_only=False
        )
    except Exception as e:
        st.error(f"❌ Failed to download model config: {str(e)}")
        st.info("⚠️ Trying alternative download method...")
        # NOTE(review): verify the pinned huggingface_hub version exposes
        # HfApi.hf_hub_download — older releases only have the module function.
        api = HfApi(endpoint=PROXY_URL)
        cfg_path = api.hf_hub_download(
            repo_id=MODEL_ID,
            filename="config.json",
            repo_type="model",
            cache_dir=f"{CACHE_DIR}/huggingface_hub",
            local_files_only=False
        )

    # BUG FIX: use a context manager instead of json.load(open(...)) so the
    # config file handle is always closed.
    with open(cfg_path, "r") as cfg_file:
        cfg = json.load(cfg_file)  # currently unused; kept for future use

    # Load models with explicit cache directories.
    try:
        vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224",
            ignore_mismatched_sizes=True,
            cache_dir=f"{CACHE_DIR}/transformers"
        ).to(device)
    except Exception as e:
        st.warning(f"⚠️ Standard ViT download failed: {str(e)}")
        # NOTE(review): the `mirror=` kwarg was removed from from_pretrained in
        # modern transformers — confirm the pinned version still supports it.
        vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224",
            ignore_mismatched_sizes=True,
            cache_dir=f"{CACHE_DIR}/transformers",
            mirror=PROXY_URL
        ).to(device)

    try:
        t5 = T5ForConditionalGeneration.from_pretrained(
            "t5-base",
            cache_dir=f"{CACHE_DIR}/transformers"
        ).to(device)
    except Exception as e:
        st.warning(f"⚠️ Standard T5 download failed: {str(e)}")
        t5 = T5ForConditionalGeneration.from_pretrained(
            "t5-base",
            cache_dir=f"{CACHE_DIR}/transformers",
            mirror=PROXY_URL
        ).to(device)

    try:
        tok = T5Tokenizer.from_pretrained(
            MODEL_ID,
            cache_dir=f"{CACHE_DIR}/transformers"
        )
    except Exception as e:
        st.warning(f"⚠️ Standard tokenizer download failed: {str(e)}")
        tok = T5Tokenizer.from_pretrained(
            MODEL_ID,
            cache_dir=f"{CACHE_DIR}/transformers",
            mirror=PROXY_URL
        )

    # Assemble the combined model and pull the fine-tuned weights.
    model = TwoViewVisionReportModel(vit, t5, tok).to(device)
    try:
        ckpt_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="pytorch_model.bin",
            repo_type="model",
            cache_dir=f"{CACHE_DIR}/huggingface_hub",
            local_files_only=False
        )
    except Exception as e:
        st.warning(f"⚠️ Standard model weights download failed: {str(e)}")
        api = HfApi(endpoint=PROXY_URL)
        ckpt_path = api.hf_hub_download(
            repo_id=MODEL_ID,
            filename="pytorch_model.bin",
            repo_type="model",
            cache_dir=f"{CACHE_DIR}/huggingface_hub",
            local_files_only=False
        )

    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — acceptable only because the checkpoint comes from our own repo.
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    return device, model, tok


# ──── APP INTERFACE ────────────────────────────────────────────────────────────
try:
    device, model, tokenizer = load_models()
except Exception as e:
    st.error(f"🚨 Critical Error: Failed to load models. {str(e)}")
    st.info("Please try refreshing the page or contact support@example.com")
    st.stop()
# Preprocessing: resize to the ViT input size and scale pixels to [-1, 1]
# (scalar mean/std broadcast over the three channels).
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5),
])

# Streamlit configuration
st.set_page_config(
    page_title="Radiology Report Analysis",
    layout="wide",
    initial_sidebar_state="collapsed"
)

# BUG FIX: st.experimental_rerun() was removed in newer Streamlit releases.
# Resolve the available rerun function once, preferring the modern st.rerun.
_rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun", None)

# Custom CSS — NOTE(review): the injected markup was lost when this file was
# extracted; restore the original HTML/CSS payload here.
st.markdown("""
""", unsafe_allow_html=True)

st.markdown("""
🩺 Radiology Report Analysis
""", unsafe_allow_html=True)
st.markdown("""
Upload a chest X-ray and click Generate Report
""", unsafe_allow_html=True)

# File upload handling — the uploaded file is parked in session_state so it
# survives Streamlit's rerun-on-interaction execution model.
if "img" not in st.session_state:
    uploaded = st.file_uploader("📤 Upload X-ray (PNG/JPG)", type=["png", "jpg", "jpeg"])
    if uploaded:
        try:
            img = Image.open(uploaded).convert("RGB")
            # Quick verification by thumbnail generation
            img.thumbnail((10, 10))
            st.session_state.img = uploaded
            _rerun()
        except Exception as e:
            st.error(f"❌ Invalid image file: {str(e)}")
            st.stop()
    else:
        st.stop()

img_file = st.session_state.img
# BUG FIX: the verification pass above consumed the upload stream; rewind
# before re-opening, otherwise Image.open can fail on a mid-stream pointer.
img_file.seek(0)
img = Image.open(img_file).convert("RGB")
# NOTE(review): use_column_width is deprecated in favor of use_container_width
# on recent Streamlit for st.image — switch once the minimum version allows.
st.image(img, use_column_width=True, caption="Uploaded X-ray")

col1, col2 = st.columns(2)
with col1:
    if st.button("▶️ Generate Report", use_container_width=True,
                 type="primary", key="generate"):
        with st.spinner("Analyzing X-ray. This may take 10-20 seconds..."):
            try:
                px = transform(img).unsqueeze(0).to(device)
                out_ids = model.generate(px, max_length=128)
                report = tokenizer.decode(out_ids[0], skip_special_tokens=True)
                st.subheader("📝 AI-Generated Report")
                st.success(report)
            except Exception as e:
                st.error(f"❌ Analysis failed: {str(e)}")
                st.info("Please try with a different image or try again later")
with col2:
    if st.button("⬅️ Upload Another", use_container_width=True, key="upload_another"):
        # Drop the cached upload and restart so the uploader is shown again.
        del st.session_state.img
        _rerun()

# Add footer
st.markdown("---")
st.markdown("""
**Note:**
- First-time model loading may take 1-2 minutes
- For optimal results, use clear chest X-ray images
- Contact support@example.com for assistance
""")