# Rakesh-Radiology-Report-App / src/streamlit_app.py
# RakeshNJ12345's picture
# Update src/streamlit_app.py
# 01dd584 verified
# src/streamlit_app.py
# ─── 1) Force all HF/Streamlit temp dirs into /tmp ─────────────────────────────
import html
import os
# Point the home, config and cache locations used by Streamlit and the
# Hugging Face libraries at paths under /tmp.
_TMP_ENV = {
    "HOME": "/tmp",
    "XDG_CONFIG_HOME": "/tmp",
    "STREAMLIT_HOME": "/tmp",
    "XDG_CACHE_HOME": "/tmp",
    "HF_HOME": "/tmp/hf",
    "TRANSFORMERS_CACHE": "/tmp/hf/transformers",
}
os.environ.update(_TMP_ENV)

# Create the working directories up front; already-existing ones are fine.
for _tmp_dir in ("/tmp/streamlit", "/tmp/hf/transformers"):
    os.makedirs(_tmp_dir, exist_ok=True)
# ─── 2) Imports ────────────────────────────────────────────────────────────────
import streamlit as st
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as T
from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
# ─── 3) Model repository and loader ────────────────────────────────────────────
MODEL_ID = "RakeshNJ12345/Automated-Chest-XRay-Report"


@st.cache_resource(show_spinner="Loading fine-tuned ViT & T5…")
def load_models():
    """Load the ViT encoder, the T5 model and its tokenizer from ``MODEL_ID``.

    The ViT weights are read from the ``vit`` subfolder and the T5 weights and
    tokenizer from the ``T5`` subfolder of the repository. Everything except
    the tokenizer is moved to CUDA when available, otherwise to the CPU. The
    result is cached by ``st.cache_resource``.

    Returns:
        tuple: ``(device, vit, proj, t5, tok)`` where ``proj`` is a linear
        layer mapping the ViT hidden size to the T5 model dimension.
    """
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    t5_location = {"subfolder": "T5"}

    # Vision encoder, read from the "vit" subfolder.
    vision_encoder = ViTModel.from_pretrained(
        MODEL_ID,
        subfolder="vit",
        ignore_mismatched_sizes=True,
    )
    vision_encoder = vision_encoder.to(run_device)

    # Text model and tokenizer, both read from the "T5" subfolder.
    text_model = T5ForConditionalGeneration.from_pretrained(MODEL_ID, **t5_location)
    text_model = text_model.to(run_device)
    text_tokenizer = T5Tokenizer.from_pretrained(MODEL_ID, **t5_location)

    # Built once here so the same layer object is reused on every call.
    # NOTE(review): this layer is constructed fresh and no saved weights are
    # loaded into it in this function — confirm whether trained projection
    # weights exist and should be restored.
    bridge = nn.Linear(vision_encoder.config.hidden_size, text_model.config.d_model)
    bridge = bridge.to(run_device)

    return run_device, vision_encoder, bridge, text_model, text_tokenizer
# ─── 4) Page config ────────────────────────────────────────────────────────────
# Must run before anything else draws to the page: the cached model loader
# below is configured with a spinner, and Streamlit documents set_page_config
# as the first Streamlit command of a script.
st.set_page_config(page_title="Radiology Report Analysis", layout="wide")

# ─── 5) Kick off model loading ─────────────────────────────────────────────────
device, vit, proj, t5, tokenizer = load_models()

# ─── 6) Image preprocessing ────────────────────────────────────────────────────
# Resize to 224x224, convert to a tensor, then normalise each of the three
# channels with mean 0.5 / std 0.5 (spelled out per channel rather than relying
# on scalar broadcasting).
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# ─── 7) Streamlit UI header ────────────────────────────────────────────────────
st.markdown(
    "<h1 style='text-align:center;'>🩺 Radiology Report Analysis</h1>",
    unsafe_allow_html=True,
)
st.markdown(
    "<p style='text-align:center;'>Upload a chest X-ray and click <b>Generate Report</b>.</p>",
    unsafe_allow_html=True,
)

# The app is a two-screen flow ("upload" → "report") tracked in session state.
if "stage" not in st.session_state:
    st.session_state.stage = "upload"
# ─── Screen routing ────────────────────────────────────────────────────────────
# st.experimental_rerun is deprecated in favour of st.rerun; prefer the new
# name and fall back to the old one on installs that do not have it yet.
_rerun = getattr(st, "rerun", None) or st.experimental_rerun

# A "report" stage without a stored upload cannot be rendered; go back to the
# upload screen instead of failing on the missing session entry.
if st.session_state.stage == "report" and st.session_state.get("uploaded") is None:
    st.session_state.stage = "upload"

# ─── UPLOAD SCREEN ─────────────────────────────────────────────────────────────
if st.session_state.stage == "upload":
    uploaded = st.file_uploader(
        "📤 Upload Chest X-ray (PNG/JPG)",
        type=["png", "jpg", "jpeg"],
        label_visibility="visible",
    )
    if uploaded:
        st.image(
            uploaded,
            width=350,
            caption=f"{uploaded.name} — {uploaded.size/1e6:.2f} MB",
        )
        if st.button("▶️ Generate Report"):
            # Keep the file in session state so it survives the rerun that
            # switches to the report screen.
            st.session_state.uploaded = uploaded
            st.session_state.stage = "report"
            _rerun()

# ─── REPORT SCREEN ─────────────────────────────────────────────────────────────
elif st.session_state.stage == "report":
    img_file = st.session_state.uploaded
    img = Image.open(img_file).convert("RGB")

    # Pure inference: no_grad avoids building autograd graphs for the ViT and
    # T5 encoder forward passes.
    with st.spinner("🔍 Analyzing…"), torch.no_grad():
        # 1) ViT features
        x = transform(img).unsqueeze(0).to(device)            # [1,3,224,224]
        vit_out = vit(pixel_values=x).pooler_output           # [1,hidden]

        # 2) project → single-token visual prefix
        vision_pref = proj(vit_out).unsqueeze(1)              # [1,1,d_model]

        # 3) prepare prompt tokens
        enc = tokenizer("report:", return_tensors="pt").to(device)
        txt_emb = t5.encoder.embed_tokens(enc.input_ids)      # [1,L,d_model]

        # 4) concat embeddings + mask (one extra mask position for the prefix)
        enc_emb = torch.cat([vision_pref, txt_emb], dim=1)    # [1,1+L,d_model]
        enc_mask = torch.cat(
            [
                torch.ones(1, 1, device=device, dtype=torch.long),
                enc.attention_mask,
            ],
            dim=1,
        )                                                     # [1,1+L]

        # 5) feed into T5 encoder
        enc_out = t5.encoder(
            inputs_embeds=enc_emb,
            attention_mask=enc_mask,
        )

        # 6) generate from the precomputed encoder states; no
        #    encoder_attention_mask kwarg is passed to generate()
        out_ids = t5.generate(
            encoder_outputs=enc_out,
            max_length=64,
            num_beams=2,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
        )
        report = tokenizer.decode(out_ids[0], skip_special_tokens=True)

    # ── display side by side ────────────────────────────────────────────────
    c1, c2 = st.columns(2)
    with c1:
        st.subheader("📤 Uploaded X-ray")
        st.image(img, use_container_width=True)
        st.markdown(f"**Filename:** {img_file.name}")
        st.markdown(f"**Size:** {img_file.size/1e6:.2f} MB")
    with c2:
        st.subheader("📝 AI Diagnosis & Report")
        # The generated text is placed inside raw HTML, so escape it first to
        # keep any markup-like characters in the output from being rendered.
        st.markdown(
            "<div style='background:#e0f7fa;padding:12px;border-radius:6px;'>"
            f"{html.escape(report)}</div>",
            unsafe_allow_html=True,
        )

    if st.button("⬅️ Upload Another"):
        del st.session_state.uploaded
        st.session_state.stage = "upload"
        _rerun()
# ─── Footer ────────────────────────────────────────────────────────────────────
# Static credits line rendered as raw HTML below whichever screen is active.
st.markdown("""
<hr>
<p style='text-align:center;color:gray;font-size:0.85em;'>
⚙️ Powered by fine-tuned ViT + T5 · Built with Streamlit · Hosted on HF Spaces
</p>
""", unsafe_allow_html=True)