# Automated_Chest-XRay_Report / src/streamlit_app.py
# RakeshNJ12345's picture
# Update src/streamlit_app.py
# 020994a verified
# streamlit_app.py
# ──── SET ENVIRONMENT VARIABLES BEFORE ANY IMPORTS ──────────────────────────────
import os
import tempfile
# Dedicated locations under /tmp for caches and Streamlit files.
# CACHE_DIR is the root for the Hugging Face / transformers cache paths set
# further down; STREAMLIT_DIR is exported as STREAMLIT_HOME and holds the
# generated config/config.toml.
CACHE_DIR = "/tmp/hf_cache"
STREAMLIT_DIR = "/tmp/streamlit"
# Create directories safely without recursion
def safe_makedirs(path):
    """Create directory *path* (and missing parents) on a best-effort basis.

    A filesystem failure is reported with a printed warning instead of an
    exception, so module start-up continues even when the location cannot be
    created.

    Args:
        path: Directory to create. An already-existing directory is accepted.

    Returns:
        True if the directory was created or already existed, False if the
        filesystem refused the request.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        # Only filesystem errors are tolerated here; anything else (for
        # example a non-path argument) is a programming error and surfaces.
        print(f"Warning: Could not create {path}: {str(e)}")
        return False
    return True
# Make sure the two base directories exist before anything points at them.
safe_makedirs(CACHE_DIR)
safe_makedirs(STREAMLIT_DIR)

# Point HOME, the XDG locations and the Hugging Face cache variables at the
# directories above. This must happen before the third-party imports below so
# those libraries see the values when they are first imported.
# NOTE(review): confirm that Streamlit honours STREAMLIT_HOME and
# STREAMLIT_SERVER_ENABLE_FILE_WATCHER; the documented option is
# server.fileWatcherType.
os.environ.update({
    "HOME": "/tmp",
    "XDG_CONFIG_HOME": "/tmp",
    "STREAMLIT_HOME": STREAMLIT_DIR,
    "XDG_CACHE_HOME": CACHE_DIR,
    "HF_HOME": f"{CACHE_DIR}/huggingface",
    "TRANSFORMERS_CACHE": f"{CACHE_DIR}/transformers",
    "HF_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
    "HUGGINGFACE_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
    "HF_HUB_DISABLE_TELEMETRY": "1",
    "STREAMLIT_SERVER_ENABLE_FILE_WATCHER": "false",
})

# Create every cache subdirectory referenced above, plus the config directory.
for path in (
    f"{CACHE_DIR}/huggingface",
    f"{CACHE_DIR}/transformers",
    f"{CACHE_DIR}/huggingface_hub",
    f"{STREAMLIT_DIR}/config",
):
    safe_makedirs(path)

# Write a Streamlit config that disables usage stats and the file watcher.
# An existing file is left untouched.
# NOTE(review): Streamlit documents ~/.streamlit/config.toml and
# <cwd>/.streamlit/config.toml as its config locations; confirm this path is
# actually read.
CONFIG_TOML = f"{STREAMLIT_DIR}/config/config.toml"
if not os.path.exists(CONFIG_TOML):
    try:
        with open(CONFIG_TOML, "w", encoding="utf-8") as f:
            f.write(
                "[browser]\n"
                "gatherUsageStats = false\n"
                "[server]\n"
                # TOML string values must be quoted; a bare `none` is a
                # syntax error that makes the whole file unparseable.
                'fileWatcherType = "none"\n'
            )
    except OSError as e:
        print(f"Warning: Could not create Streamlit config: {str(e)}")
# ──── NOW IMPORT OTHER LIBRARIES ───────────────────────────────────────────────
import json
import torch
import torch.nn as nn
import torchvision.transforms as T
import streamlit as st
from PIL import Image
from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
from huggingface_hub import hf_hub_download, HfApi
# ──── MODEL DEFINITION ─────────────────────────────────────────────────────────
# Hugging Face Hub repository that load_models() pulls the tokenizer,
# config.json and the pytorch_model.bin checkpoint from.
MODEL_ID = "RakeshNJ12345/Chest-Radiology"
# Alternative endpoint, used only after a direct download attempt has failed.
PROXY_URL = "https://hf-mirror.com" # Proxy for Hugging Face downloads
class TwoViewVisionReportModel(nn.Module):
    """ViT image encoder driving a T5 text generator through a learned prefix.

    The pooled ViT embedding is projected to the T5 model dimension and
    prepended, as one pseudo-token, to the embedded text prompt ``"report:"``.
    The T5 encoder runs on that sequence and the decoder generates token ids.

    ``proj_l`` is registered but not used by :meth:`generate`, which handles a
    single view only. It is kept so the module's parameter set is unchanged.
    """

    def __init__(self, vit: ViTModel, t5: T5ForConditionalGeneration, tokenizer: T5Tokenizer):
        """Wrap the given backbones and create the two projection layers.

        Args:
            vit: Vision backbone; its ``pooler_output`` feeds the projection.
            t5: Text model used for encoding the prefixed prompt and decoding.
            tokenizer: Tokenizer used for the prompt and the EOS token id.
        """
        super().__init__()
        self.vit = vit
        self.proj_f = nn.Linear(vit.config.hidden_size, t5.config.d_model)
        self.proj_l = nn.Linear(vit.config.hidden_size, t5.config.d_model)
        self.tokenizer = tokenizer
        self.t5 = t5

    @torch.no_grad()  # inference only: skip autograd bookkeeping for ViT/encoder
    def generate(self, img: torch.Tensor, max_length: int = 128) -> torch.Tensor:
        """Greedy-decode report token ids for a batch of images.

        Args:
            img: Pixel tensor passed to the ViT as ``pixel_values``; the first
                dimension is treated as the batch size
                (assumes ``(batch, 3, H, W)`` — confirm against the caller).
            max_length: Maximum length passed to ``T5.generate``.

        Returns:
            Tensor of generated token ids, one row per input image.
        """
        device = img.device
        batch = img.shape[0]

        # Image -> single prefix "token" in T5 embedding space.
        vf = self.vit(pixel_values=img).pooler_output
        prefix = self.proj_f(vf).unsqueeze(1)  # single-view only

        # The text prompt is tokenized once and shared by every batch item.
        enc = self.tokenizer("report:", return_tensors="pt").to(device)
        txt_emb = self.t5.encoder.embed_tokens(enc.input_ids).expand(batch, -1, -1)
        txt_mask = enc.attention_mask.expand(batch, -1)

        enc_emb = torch.cat([prefix, txt_emb], dim=1)
        enc_mask = torch.cat(
            [torch.ones(batch, 1, device=device, dtype=txt_mask.dtype), txt_mask],
            dim=1,
        )

        enc_out = self.t5.encoder(
            inputs_embeds=enc_emb,
            attention_mask=enc_mask,
        )
        # T5 takes the encoder-side mask as `attention_mask`; its forward()
        # has no `encoder_attention_mask` parameter.
        out_ids = self.t5.generate(
            encoder_outputs=enc_out,
            attention_mask=enc_mask,
            max_length=max_length,
            num_beams=1,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        return out_ids
# ──── MODEL LOADING WITH ERROR HANDLING ────────────────────────────────────────
def _download_repo_file(filename, report_failure):
    """Fetch *filename* from the MODEL_ID repository, retrying via PROXY_URL.

    Args:
        filename: Name of the file inside the model repository.
        report_failure: Callable invoked with the exception raised by the
            direct attempt, before the alternative endpoint is tried.

    Returns:
        Local path of the downloaded (cached) file.

    Raises:
        Exception: Whatever the second attempt raises if it fails as well.
    """
    request = dict(
        repo_id=MODEL_ID,
        filename=filename,
        repo_type="model",
        cache_dir=f"{CACHE_DIR}/huggingface_hub",
        local_files_only=False,
    )
    try:
        return hf_hub_download(**request)
    except Exception as e:  # any failure falls through to the alternative endpoint
        report_failure(e)
        return HfApi(endpoint=PROXY_URL).hf_hub_download(**request)


def _from_pretrained_with_fallback(loader, name, label, **kwargs):
    """Call ``loader.from_pretrained(name, **kwargs)``, retrying with a mirror.

    Args:
        loader: Class exposing ``from_pretrained`` (model or tokenizer class).
        name: Model identifier passed to ``from_pretrained``.
        label: Human-readable name used in the warning shown on failure.
        **kwargs: Extra keyword arguments forwarded to both attempts.

    Returns:
        The object returned by ``from_pretrained``.
    """
    try:
        return loader.from_pretrained(name, **kwargs)
    except Exception as e:
        st.warning(f"⚠️ Standard {label} download failed: {str(e)}")
        # NOTE(review): confirm the installed transformers version still
        # honours `mirror`; if it is ignored this retry equals the first try.
        return loader.from_pretrained(name, mirror=PROXY_URL, **kwargs)


@st.cache_resource(show_spinner=False)
def load_models():
    """Download and assemble the report-generation model (cached by Streamlit).

    Returns:
        Tuple ``(device, model, tokenizer)``: the torch device in use, the
        ``TwoViewVisionReportModel`` with the repository checkpoint loaded and
        switched to eval mode, and the T5 tokenizer.

    Raises:
        Exception: If a download fails on both attempts, or the checkpoint
            does not match the model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Ensure cache directories exist
    for path in (
        f"{CACHE_DIR}/huggingface",
        f"{CACHE_DIR}/transformers",
        f"{CACHE_DIR}/huggingface_hub",
    ):
        safe_makedirs(path)

    def _report_config_failure(e):
        st.error(f"❌ Failed to download model config: {str(e)}")
        st.info("⚠️ Trying alternative download method...")

    def _report_weights_failure(e):
        st.warning(f"⚠️ Standard model weights download failed: {str(e)}")

    # The config is parsed so that a malformed download raises here; its
    # values are not used further in this function.
    cfg_path = _download_repo_file("config.json", _report_config_failure)
    with open(cfg_path, "r", encoding="utf-8") as cfg_file:
        json.load(cfg_file)

    # Load the pretrained backbones and tokenizer with an explicit cache dir.
    transformers_cache = f"{CACHE_DIR}/transformers"
    vit = _from_pretrained_with_fallback(
        ViTModel,
        "google/vit-base-patch16-224",
        "ViT",
        ignore_mismatched_sizes=True,
        cache_dir=transformers_cache,
    ).to(device)
    t5 = _from_pretrained_with_fallback(
        T5ForConditionalGeneration,
        "t5-base",
        "T5",
        cache_dir=transformers_cache,
    ).to(device)
    tok = _from_pretrained_with_fallback(
        T5Tokenizer,
        MODEL_ID,
        "tokenizer",
        cache_dir=transformers_cache,
    )

    # Combine the parts, then overwrite their weights with the checkpoint.
    model = TwoViewVisionReportModel(vit, t5, tok).to(device)
    ckpt_path = _download_repo_file("pytorch_model.bin", _report_weights_failure)
    # SECURITY: torch.load unpickles the file, and this checkpoint comes from
    # the network. Consider passing weights_only=True once the minimum
    # supported PyTorch version accepts it.
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    # The wrapper module starts in training mode; inference needs eval mode.
    model.eval()
    return device, model, tok
# ──── APP INTERFACE ───────────────────────────────────────────────────────────
# Streamlit configuration.
# set_page_config is documented as having to be the first Streamlit command of
# a run. load_models() may emit st.warning/st.error, so configure the page
# before calling it.
st.set_page_config(
    page_title="Radiology Report Analysis",
    layout="wide",
    initial_sidebar_state="collapsed"
)


def _rerun():
    """Restart the script run, using whichever rerun API this Streamlit has."""
    # st.rerun replaces the deprecated st.experimental_rerun; fall back only
    # when the newer name is missing.
    rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun")
    rerun()


# Load (or fetch from Streamlit's resource cache) the model stack. Without it
# the app cannot do anything, so report the failure and halt.
try:
    device, model, tokenizer = load_models()
except Exception as e:
    st.error(f"🚨 Critical Error: Failed to load models. {str(e)}")
    st.info("Please try refreshing the page or contact support@example.com")
    st.stop()

# Preprocessing applied to each uploaded image before it reaches the model.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5),
])

# Custom CSS
st.markdown("""
<style>
.reportview-container .main .block-container {padding-top: 2rem;}
header {visibility: hidden;}
.stDeployButton {display:none;}
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stApp {background-color: #f8f9fa;}
</style>
""", unsafe_allow_html=True)
st.markdown("<h1 style='text-align:center; color:#2c3e50;'>🩺 Radiology Report Analysis</h1>", unsafe_allow_html=True)
st.markdown("<p style='text-align:center; color:#7f8c8d;'>Upload a chest X-ray and click Generate Report</p>", unsafe_allow_html=True)

# File upload handling: until an image is stored in the session, show only the
# uploader and stop the run.
if "img" not in st.session_state:
    uploaded = st.file_uploader("📤 Upload X-ray (PNG/JPG)", type=["png", "jpg", "jpeg"])
    if not uploaded:
        st.stop()
    try:
        # Decoding to RGB and shrinking forces PIL to read the pixel data, so
        # an unreadable file is rejected here and not later.
        probe = Image.open(uploaded).convert("RGB")
        probe.thumbnail((10, 10))
    except Exception as e:
        st.error(f"❌ Invalid image file: {str(e)}")
        st.stop()
    # The rerun stays outside the try block so that only image decoding is
    # guarded by the handler above.
    st.session_state.img = uploaded
    _rerun()

img_file = st.session_state.img
img = Image.open(img_file).convert("RGB")
st.image(img, use_column_width=True, caption="Uploaded X-ray")

col1, col2 = st.columns(2)
with col1:
    if st.button("▶️ Generate Report", use_container_width=True, type="primary", key="generate"):
        with st.spinner("Analyzing X-ray. This may take 10-20 seconds..."):
            try:
                px = transform(img).unsqueeze(0).to(device)
                # Inference only: no autograd graph is needed.
                with torch.no_grad():
                    out_ids = model.generate(px, max_length=128)
                report = tokenizer.decode(out_ids[0], skip_special_tokens=True)
                st.subheader("📝 AI-Generated Report")
                st.success(report)
            except Exception as e:
                st.error(f"❌ Analysis failed: {str(e)}")
                st.info("Please try with a different image or try again later")
with col2:
    if st.button("⬅️ Upload Another", use_container_width=True, key="upload_another"):
        # Dropping the stored image brings the uploader back on the next run.
        del st.session_state.img
        _rerun()

# Add footer
st.markdown("---")
st.markdown("""
**Note:**
- First-time model loading may take 1-2 minutes
- For optimal results, use clear chest X-ray images
- Contact support@example.com for assistance
""")