# src/streamlit_app.py
"""Streamlit front-end that generates a chest X-ray report with ViT + T5.

Flow: the user uploads an image on the "upload" screen, clicks
"Generate Report", and is taken to the "report" screen where the image and
the generated text are shown side by side.
"""

# ─── 1) Force all HF/Streamlit temp dirs into /tmp ─────────────────────────────
# These variables must be set BEFORE streamlit / transformers are imported,
# which is why this section precedes the third-party imports.
import html
import os

for k, v in [
    ("HOME", "/tmp"),
    ("XDG_CONFIG_HOME", "/tmp"),
    ("STREAMLIT_HOME", "/tmp"),
    ("XDG_CACHE_HOME", "/tmp"),
    ("HF_HOME", "/tmp/hf"),
    ("TRANSFORMERS_CACHE", "/tmp/hf/transformers"),
]:
    os.environ[k] = v

# create those dirs
for d in ["/tmp/streamlit", "/tmp/hf/transformers"]:
    os.makedirs(d, exist_ok=True)

# ─── 2) Imports ────────────────────────────────────────────────────────────────
import streamlit as st
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as T
from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer

# ─── 3) Your model repo & subfolders ───────────────────────────────────────────
MODEL_ID = "RakeshNJ12345/Automated-Chest-XRay-Report"


@st.cache_resource(show_spinner="Loading fine-tuned ViT & T5…")
def load_models():
    """Load the ViT encoder, T5 model, tokenizer and projection layer once.

    The result is cached by Streamlit, so the models are only loaded on the
    first call of the process.

    Returns:
        A tuple ``(device, vit, proj, t5, tok)`` where ``device`` is CUDA when
        available and CPU otherwise, and every module has been moved to it.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1) load fine-tuned ViT from vit/
    vit = ViTModel.from_pretrained(
        MODEL_ID,
        subfolder="vit",
        ignore_mismatched_sizes=True,
    ).to(device)

    # 2) load fine-tuned T5 + tokenizer from T5/
    t5 = T5ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        subfolder="T5",
    ).to(device)
    tok = T5Tokenizer.from_pretrained(
        MODEL_ID,
        subfolder="T5",
    )

    # 3) build a single Linear projection so we reuse it each call
    # NOTE(review): this layer is freshly initialised here; no trained weights
    # are loaded for it in this file. If the checkpoint ships projection
    # weights they should be loaded, otherwise the vision prefix is random.
    proj = nn.Linear(vit.config.hidden_size, t5.config.d_model).to(device)

    return device, vit, proj, t5, tok


# ─── 4) Page config, then kick off model loading ───────────────────────────────
# set_page_config must be the first Streamlit command of the script, and the
# cached loader shows a spinner, so the page is configured before loading.
st.set_page_config(page_title="Radiology Report Analysis", layout="wide")

device, vit, proj, t5, tokenizer = load_models()

# ─── 5) Image preprocessing ────────────────────────────────────────────────────
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5),
])


def _generate_report(img):
    """Run ViT → projection → T5 on a PIL RGB image and return the report text."""
    # No gradients are needed for inference; skipping them saves memory.
    with torch.no_grad():
        # 1) ViT features
        x = transform(img).unsqueeze(0).to(device)      # [1,3,224,224]
        vit_out = vit(pixel_values=x).pooler_output     # [1,hidden_size]

        # 2) project → prefix
        vision_pref = proj(vit_out).unsqueeze(1)        # [1,1,d_model]

        # 3) prepare prompt tokens
        enc = tokenizer("report:", return_tensors="pt").to(device)
        txt_emb = t5.encoder.embed_tokens(enc.input_ids)  # [1,L,d_model]

        # 4) concat embeddings + mask
        enc_emb = torch.cat([vision_pref, txt_emb], dim=1)  # [1,1+L,d]
        enc_mask = torch.cat(
            [
                torch.ones(1, 1, device=device, dtype=torch.long),
                enc.attention_mask,
            ],
            dim=1,
        )  # [1,1+L]

        # 5) feed into T5 encoder
        enc_out = t5.encoder(
            inputs_embeds=enc_emb,
            attention_mask=enc_mask,
        )

        # 6) generate **without** encoder_attention_mask kwarg
        out_ids = t5.generate(
            encoder_outputs=enc_out,
            max_length=64,
            num_beams=2,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(out_ids[0], skip_special_tokens=True)


# ─── 6) Streamlit UI ───────────────────────────────────────────────────────────
# NOTE(review): the HTML wrappers in the markdown strings below were lost from
# the source and have been reconstructed with minimal markup — confirm styling.
st.markdown(
    "<h1 style='text-align:center;'>🩺 Radiology Report Analysis</h1>",
    unsafe_allow_html=True,
)
st.markdown(
    "<p style='text-align:center;'>"
    "Upload a chest X-ray and click <b>Generate Report</b>."
    "</p>",
    unsafe_allow_html=True,
)

if "stage" not in st.session_state:
    st.session_state.stage = "upload"

# ─── 7) UPLOAD SCREEN ──────────────────────────────────────────────────────────
if st.session_state.stage == "upload":
    uploaded = st.file_uploader(
        "📤 Upload Chest X-ray (PNG/JPG)",
        type=["png", "jpg", "jpeg"],
        label_visibility="visible",
    )
    if uploaded:
        st.image(
            uploaded,
            width=350,
            caption=f"{uploaded.name} — {uploaded.size/1e6:.2f} MB",
        )
        if st.button("▶️ Generate Report"):
            st.session_state.uploaded = uploaded
            # Drop any report left over from a previous image.
            st.session_state.pop("report", None)
            st.session_state.stage = "report"
            # st.experimental_rerun() was removed from Streamlit; st.rerun()
            # is its replacement.
            st.rerun()

# ─── 8) REPORT SCREEN ──────────────────────────────────────────────────────────
elif st.session_state.stage == "report":
    img_file = st.session_state.uploaded
    img = Image.open(img_file).convert("RGB")

    # Streamlit re-executes the whole script on every interaction (including
    # the click on "Upload Another"), so the report is generated once and kept
    # in session state instead of re-running inference each time.
    if "report" not in st.session_state:
        with st.spinner("🔍 Analyzing…"):
            st.session_state.report = _generate_report(img)
    report = st.session_state.report

    # ── display side by side ────────────────────────────────────────────────
    c1, c2 = st.columns(2)
    with c1:
        st.subheader("📤 Uploaded X-ray")
        st.image(img, use_container_width=True)
        st.markdown(f"**Filename:** {img_file.name}")
        st.markdown(f"**Size:** {img_file.size/1e6:.2f} MB")
    with c2:
        st.subheader("📝 AI Diagnosis & Report")
        # The text is model output rendered as raw HTML, so escape it first.
        st.markdown(
            f"<div style='font-size:1.1em;'>{html.escape(report)}</div>",
            unsafe_allow_html=True,
        )

    if st.button("⬅️ Upload Another"):
        del st.session_state.uploaded
        st.session_state.pop("report", None)
        st.session_state.stage = "upload"
        st.rerun()

# ─── Footer ────────────────────────────────────────────────────────────────────
st.markdown(
    """
<hr>
<p style='text-align:center;'>
⚙️ Powered by fine-tuned ViT + T5 · Built with Streamlit · Hosted on HF Spaces
</p>
""",
    unsafe_allow_html=True,
)