import os

# Point every config/cache location at /tmp. This has to run before
# `transformers` is imported further down: huggingface_hub resolves HF_HOME
# and the cache directories from the environment when it is imported.
# NOTE(review): presumably done because the container's real home directory
# is not writable — confirm. Under `streamlit run`, Streamlit itself is
# already imported before this script executes, so the Streamlit-related
# overrides (HOME / STREAMLIT_HOME) may take effect too late for it.
_ENV_OVERRIDES = {
    "HOME": "/tmp",
    "XDG_CONFIG_HOME": "/tmp",
    "STREAMLIT_HOME": "/tmp",
    "XDG_CACHE_HOME": "/tmp",
    "HF_HOME": "/tmp/hf",
    # NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
    # releases in favour of HF_HOME; kept so the cache path is unchanged.
    "TRANSFORMERS_CACHE": "/tmp/hf/transformers",
}
# Existing values are overwritten on purpose.
os.environ.update(_ENV_OVERRIDES)

# Pre-create the directories; exist_ok makes this safe on every rerun.
# NOTE(review): nothing in this file points at /tmp/streamlit — confirm it
# is still needed.
for _path in ("/tmp/streamlit", "/tmp/hf/transformers"):
    os.makedirs(_path, exist_ok=True)
import html

import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import T5ForConditionalGeneration, T5Tokenizer, ViTModel
# Hugging Face Hub repository holding the fine-tuned checkpoints; the ViT
# weights live in its "vit" subfolder and the T5 model/tokenizer in "T5".
MODEL_ID = "RakeshNJ12345/Automated-Chest-XRay-Report"


@st.cache_resource(show_spinner="Loading fine-tuned ViT & T5…")
def load_models():
    """Load the vision encoder, projection, T5 model and tokenizer.

    ``st.cache_resource`` keeps the returned objects for the lifetime of the
    server process and shares them across sessions and reruns, so loading
    happens only on the first call.

    Returns:
        ``(device, vit, proj, t5, tok)`` where ``device`` is CUDA when
        available and CPU otherwise, ``vit`` is the ``ViTModel`` from the
        ``vit`` subfolder, ``proj`` maps ViT ``hidden_size`` to T5
        ``d_model``, ``t5`` is the ``T5ForConditionalGeneration`` from the
        ``T5`` subfolder and ``tok`` is its ``T5Tokenizer``. All modules are
        on ``device`` and in eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): ignore_mismatched_sizes=True makes transformers
    # re-initialise (instead of rejecting) checkpoint weights whose shapes do
    # not match the config — confirm no ViT weights are silently dropped.
    vit = ViTModel.from_pretrained(
        MODEL_ID,
        subfolder="vit",
        ignore_mismatched_sizes=True,
    ).to(device).eval()

    t5 = T5ForConditionalGeneration.from_pretrained(
        MODEL_ID, subfolder="T5"
    ).to(device).eval()
    tok = T5Tokenizer.from_pretrained(MODEL_ID, subfolder="T5")

    # NOTE(review): this projection is freshly (randomly) initialised — no
    # trained weights are loaded into it from MODEL_ID. Unless that is
    # intended, the vision prefix fed to T5 is arbitrary and changes on every
    # process restart. TODO: load the trained projection weights if the
    # checkpoint repository provides them.
    proj = nn.Linear(vit.config.hidden_size, t5.config.d_model).to(device).eval()

    return device, vit, proj, t5, tok
# Streamlit requires st.set_page_config to precede other Streamlit commands
# (at least in older releases). On a cache miss, load_models() shows a
# spinner, which is itself a Streamlit command, so the page is configured
# before the models are loaded.
st.set_page_config(page_title="Radiology Report Analysis", layout="wide")

device, vit, proj, t5, tokenizer = load_models()

# ViT preprocessing: resize to 224x224, convert the PIL image to a float
# tensor in [0, 1], then map it to [-1, 1] (the scalar mean/std broadcast
# over all channels).
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5),
])

# Page header.
st.markdown(
    "<h1 style='text-align:center;'>🩺 Radiology Report Analysis</h1>",
    unsafe_allow_html=True,
)
st.markdown(
    "<p style='text-align:center;'>Upload a chest X-ray and click <b>Generate Report</b>.</p>",
    unsafe_allow_html=True,
)
def _reset_to_upload():
    """Forget the stored upload and cached report; go back to the upload stage."""
    for key in ("uploaded", "report"):
        if key in st.session_state:
            del st.session_state[key]
    st.session_state.stage = "upload"


def _generate_report(img):
    """Run the ViT -> projection -> T5 pipeline on an RGB PIL image.

    The pooled ViT embedding is projected to T5's hidden size and prepended,
    as a single extra encoder position, to the embedded text prompt
    ``"report:"``; T5 then decodes a report from that combined encoding.

    Returns:
        The decoded report text, with special tokens removed.
    """
    # No gradients are needed; inference_mode skips autograd bookkeeping.
    with torch.inference_mode():
        # Batch of one image: transform yields (C, 224, 224), unsqueeze adds
        # the batch dimension.
        x = transform(img).unsqueeze(0).to(device)
        vit_out = vit(pixel_values=x).pooler_output

        # One "vision token" per image: (1, 1, d_model).
        vision_pref = proj(vit_out).unsqueeze(1)

        enc = tokenizer("report:", return_tensors="pt").to(device)
        txt_emb = t5.encoder.embed_tokens(enc.input_ids)

        enc_emb = torch.cat([vision_pref, txt_emb], dim=1)
        # The vision prefix is always attended to, so its mask entry is 1.
        enc_mask = torch.cat(
            [torch.ones(1, 1, device=device, dtype=torch.long), enc.attention_mask],
            dim=1,
        )

        enc_out = t5.encoder(inputs_embeds=enc_emb, attention_mask=enc_mask)

        out_ids = t5.generate(
            encoder_outputs=enc_out,
            # Pass the encoder mask explicitly so cross-attention uses it
            # rather than whatever default generate() would assume.
            attention_mask=enc_mask,
            max_length=64,
            num_beams=2,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(out_ids[0], skip_special_tokens=True)


def _render_upload_stage():
    """Show the uploader and, once a file is chosen, the Generate button."""
    uploaded = st.file_uploader(
        "📤 Upload Chest X-ray (PNG/JPG)",
        type=["png", "jpg", "jpeg"],
        label_visibility="visible",
    )
    if uploaded is None:
        return

    st.image(
        uploaded,
        width=350,
        caption=f"{uploaded.name} — {uploaded.size / 1e6:.2f} MB",
    )
    if st.button("▶️ Generate Report"):
        st.session_state.uploaded = uploaded
        # Never reuse a report cached for a previous image.
        if "report" in st.session_state:
            del st.session_state["report"]
        st.session_state.stage = "report"
        # st.experimental_rerun was removed from Streamlit; st.rerun replaces it.
        st.rerun()


def _show_results(img, img_file, report):
    """Render the uploaded image and its generated report side by side."""
    left, right = st.columns(2)
    with left:
        st.subheader("📤 Uploaded X-ray")
        st.image(img, use_container_width=True)
        st.markdown(f"**Filename:** {img_file.name}")
        st.markdown(f"**Size:** {img_file.size / 1e6:.2f} MB")
    with right:
        st.subheader("📝 AI Diagnosis & Report")
        # The report is model output injected into raw HTML; escape it so any
        # '<', '>' or '&' it contains is shown as text, not parsed as markup.
        st.markdown(
            "<div style='background:#e0f7fa;padding:12px;border-radius:6px;'>"
            f"{html.escape(report)}</div>",
            unsafe_allow_html=True,
        )


def _render_report_stage():
    """Show the stored upload with its report, generating the report once."""
    img_file = st.session_state.get("uploaded")
    if img_file is None:
        # The stored upload is gone (e.g. state was cleared): start over.
        _reset_to_upload()
        st.rerun()

    try:
        img = Image.open(img_file).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot identify, and OSError for truncated image data.
        st.error("Could not read the uploaded file as an image. Please upload a valid PNG/JPG.")
    else:
        # Generate once per upload; later reruns (e.g. the click on the
        # button below) reuse the cached report instead of re-running models.
        if "report" not in st.session_state:
            with st.spinner("🔍 Analyzing…"):
                st.session_state.report = _generate_report(img)
        _show_results(img, img_file, st.session_state.report)

    if st.button("⬅️ Upload Another"):
        _reset_to_upload()
        st.rerun()


# Two-stage flow: "upload" -> "report" -> back to "upload".
if "stage" not in st.session_state:
    st.session_state.stage = "upload"

if st.session_state.stage == "upload":
    _render_upload_stage()
elif st.session_state.stage == "report":
    _render_report_stage()
# Page footer, rendered below whichever stage is active.
st.markdown("""
<hr>
<p style='text-align:center;color:gray;font-size:0.85em;'>
⚙️ Powered by fine-tuned ViT + T5 · Built with Streamlit · Hosted on HF Spaces
</p>
""", unsafe_allow_html=True)