|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import tempfile |
|
|
|
|
|
|
|
|
# Writable scratch roots: everything lives under /tmp because the runtime
# filesystem (e.g. a container / hosted Space) may be read-only elsewhere.
CACHE_DIR = "/tmp/hf_cache"  # Hugging Face download/model cache root
STREAMLIT_DIR = "/tmp/streamlit"  # Streamlit home/config directory
|
|
|
|
|
|
|
|
def safe_makedirs(path):
    """Create *path* (including parents) if it does not already exist.

    Failures are reported as a warning instead of raised, so the app can
    still start even when a cache directory cannot be created.

    Args:
        path: Directory path to create.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        # Narrowed from bare Exception: os.makedirs raises OSError for
        # filesystem problems; programming errors should not be hidden.
        print(f"Warning: Could not create {path}: {str(e)}")
|
|
|
|
|
# Ensure the two scratch roots exist before the env vars below point at them.
safe_makedirs(CACHE_DIR)
safe_makedirs(STREAMLIT_DIR)
|
|
|
|
|
|
|
|
# Point every framework at writable /tmp locations (and disable HF telemetry
# and the Streamlit file watcher) so the app runs in sandboxed containers.
for _env_key, _env_val in {
    "HOME": "/tmp",
    "XDG_CONFIG_HOME": "/tmp",
    "STREAMLIT_HOME": STREAMLIT_DIR,
    "XDG_CACHE_HOME": CACHE_DIR,
    "HF_HOME": f"{CACHE_DIR}/huggingface",
    "TRANSFORMERS_CACHE": f"{CACHE_DIR}/transformers",
    "HF_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
    "HUGGINGFACE_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
    "HF_HUB_DISABLE_TELEMETRY": "1",
    "STREAMLIT_SERVER_ENABLE_FILE_WATCHER": "false",
}.items():
    os.environ[_env_key] = _env_val
|
|
|
|
|
|
|
|
# Pre-create every cache subdirectory the downloads below will write into.
_required_dirs = (
    f"{CACHE_DIR}/huggingface",
    f"{CACHE_DIR}/transformers",
    f"{CACHE_DIR}/huggingface_hub",
    f"{STREAMLIT_DIR}/config",
)
for _required_dir in _required_dirs:
    safe_makedirs(_required_dir)
|
|
|
|
|
|
|
|
# Write a minimal Streamlit config once (existing user config is preserved):
# disable usage stats and the file watcher.
CONFIG_TOML = f"{STREAMLIT_DIR}/config/config.toml"
if not os.path.exists(CONFIG_TOML):
    try:
        with open(CONFIG_TOML, "w") as f:
            f.write("[browser]\n")
            f.write("gatherUsageStats = false\n")
            f.write("[server]\n")
            # BUG FIX: TOML string values must be quoted. The bare token
            # `none` is invalid TOML, which makes the whole config file
            # unparseable and silently ignored by Streamlit.
            f.write('fileWatcherType = "none"\n')
    except Exception as e:
        print(f"Warning: Could not create Streamlit config: {str(e)}")
|
|
|
|
|
|
|
|
import json |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchvision.transforms as T |
|
|
import streamlit as st |
|
|
from PIL import Image |
|
|
from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer |
|
|
from huggingface_hub import hf_hub_download, HfApi |
|
|
|
|
|
|
|
|
# Fine-tuned checkpoint repo on the Hugging Face Hub.
MODEL_ID = "RakeshNJ12345/Chest-Radiology"
# Mirror endpoint used as a fallback when huggingface.co is unreachable.
PROXY_URL = "https://hf-mirror.com"
|
|
|
|
|
class TwoViewVisionReportModel(nn.Module):
    """ViT image encoder feeding a T5 decoder to generate a radiology report.

    NOTE(review): despite the "TwoView" name, only ``proj_f`` is used in
    ``generate``; ``proj_l`` (presumably for a lateral view) is created but
    never called here — likely kept so the checkpoint's state dict still
    matches this module's parameters. Confirm against the checkpoint.
    """

    def __init__(self, vit: ViTModel, t5: T5ForConditionalGeneration, tokenizer: T5Tokenizer):
        super().__init__()
        self.vit = vit
        # Project pooled ViT features into T5's embedding space (d_model).
        self.proj_f = nn.Linear(vit.config.hidden_size, t5.config.d_model)
        # Second projection head; unused by generate() (see class note).
        self.proj_l = nn.Linear(vit.config.hidden_size, t5.config.d_model)
        self.tokenizer = tokenizer
        self.t5 = t5

    def generate(self, img: torch.Tensor, max_length: int = 128) -> torch.Tensor:
        """Greedily generate report token ids for one preprocessed image.

        Args:
            img: Pixel-value tensor; the hard-coded mask below only supports
                batch size 1 — assumes shape (1, 3, 224, 224), TODO confirm.
            max_length: Maximum generated sequence length.

        Returns:
            Token-id tensor from ``T5ForConditionalGeneration.generate``.
        """
        device = img.device
        # Pooled ViT representation of the image.
        vf = self.vit(pixel_values=img).pooler_output
        # Project into T5 space and treat it as a single prefix "token".
        pf = self.proj_f(vf).unsqueeze(1)
        prefix = pf

        # Embed the literal text prompt "report:" with T5's own embeddings.
        enc = self.tokenizer("report:", return_tensors="pt").to(device)
        txt_emb = self.t5.encoder.embed_tokens(enc.input_ids)

        # Encoder input = [image prefix] + [prompt tokens]; the mask gains a
        # single 1 for the prefix (hence the batch-size-1 restriction).
        enc_emb = torch.cat([prefix, txt_emb], dim=1)
        enc_mask = torch.cat([
            torch.ones(1, 1, device=device, dtype=torch.long),
            enc.attention_mask
        ], dim=1)

        # Run the T5 encoder once over the fused embeddings.
        enc_out = self.t5.encoder(
            inputs_embeds=enc_emb,
            attention_mask=enc_mask
        )

        # Greedy decoding (num_beams=1, no sampling) from precomputed
        # encoder outputs.
        # NOTE(review): HF generate() conventionally takes the encoder mask
        # as `attention_mask=`; `encoder_attention_mask=` may be rejected by
        # newer transformers releases — verify against the pinned version.
        out_ids = self.t5.generate(
            encoder_outputs=enc_out,
            encoder_attention_mask=enc_mask,
            max_length=max_length,
            num_beams=1,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        return out_ids
|
|
|
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_models():
    """Download, assemble and cache the ViT + T5 report-generation model.

    Every download has a fallback path through the mirror proxy so the app
    can still start when huggingface.co is unreachable.

    Returns:
        tuple: ``(device, model, tok)`` — the selected ``torch.device``,
        the assembled ``TwoViewVisionReportModel`` with fine-tuned weights
        loaded, and the repo's ``T5Tokenizer``.

    Raises:
        Exception: propagated when both the direct download and the mirror
        fallback fail for any artifact.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Re-create cache directories here too: the cached resource may be
    # rebuilt in a context where the module-level setup has not run.
    for path in [
        f"{CACHE_DIR}/huggingface",
        f"{CACHE_DIR}/transformers",
        f"{CACHE_DIR}/huggingface_hub"
    ]:
        safe_makedirs(path)

    # Fetch the repo's config.json first: it doubles as a cheap probe that
    # the hub is reachable and the repo is accessible.
    try:
        cfg_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="config.json",
            repo_type="model",
            cache_dir=f"{CACHE_DIR}/huggingface_hub",
            local_files_only=False
        )
    except Exception as e:
        st.error(f"β Failed to download model config: {str(e)}")
        st.info("β οΈ Trying alternative download method...")
        api = HfApi(endpoint=PROXY_URL)
        cfg_path = api.hf_hub_download(
            repo_id=MODEL_ID,
            filename="config.json",
            repo_type="model",
            cache_dir=f"{CACHE_DIR}/huggingface_hub",
            local_files_only=False
        )

    # BUG FIX: json.load(open(...)) leaked the file handle; use a context
    # manager. The parsed config was never consumed, so the unused local is
    # dropped — this now only validates the file is well-formed JSON.
    with open(cfg_path, "r") as cfg_file:
        json.load(cfg_file)

    # ViT image encoder, with a mirror fallback if the direct pull fails.
    try:
        vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224",
            ignore_mismatched_sizes=True,
            cache_dir=f"{CACHE_DIR}/transformers"
        ).to(device)
    except Exception as e:
        st.warning(f"β οΈ Standard ViT download failed: {str(e)}")
        # NOTE(review): the `mirror` kwarg was removed from recent
        # transformers releases; confirm the pinned version accepts it
        # (otherwise point the HF_ENDPOINT env var at the mirror instead).
        vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224",
            ignore_mismatched_sizes=True,
            cache_dir=f"{CACHE_DIR}/transformers",
            mirror=PROXY_URL
        ).to(device)

    # Base T5 decoder (its weights are overwritten by the checkpoint below).
    try:
        t5 = T5ForConditionalGeneration.from_pretrained(
            "t5-base",
            cache_dir=f"{CACHE_DIR}/transformers"
        ).to(device)
    except Exception as e:
        st.warning(f"β οΈ Standard T5 download failed: {str(e)}")
        t5 = T5ForConditionalGeneration.from_pretrained(
            "t5-base",
            cache_dir=f"{CACHE_DIR}/transformers",
            mirror=PROXY_URL
        ).to(device)

    # Tokenizer from the fine-tuned repo (may carry added special tokens).
    try:
        tok = T5Tokenizer.from_pretrained(
            MODEL_ID,
            cache_dir=f"{CACHE_DIR}/transformers"
        )
    except Exception as e:
        st.warning(f"β οΈ Standard tokenizer download failed: {str(e)}")
        tok = T5Tokenizer.from_pretrained(
            MODEL_ID,
            cache_dir=f"{CACHE_DIR}/transformers",
            mirror=PROXY_URL
        )

    # Assemble the combined model, then load the fine-tuned checkpoint.
    model = TwoViewVisionReportModel(vit, t5, tok).to(device)

    try:
        ckpt_path = hf_hub_download(
            repo_id=MODEL_ID,
            filename="pytorch_model.bin",
            repo_type="model",
            cache_dir=f"{CACHE_DIR}/huggingface_hub",
            local_files_only=False
        )
    except Exception as e:
        st.warning(f"β οΈ Standard model weights download failed: {str(e)}")
        api = HfApi(endpoint=PROXY_URL)
        ckpt_path = api.hf_hub_download(
            repo_id=MODEL_ID,
            filename="pytorch_model.bin",
            repo_type="model",
            cache_dir=f"{CACHE_DIR}/huggingface_hub",
            local_files_only=False
        )

    # NOTE(review): torch >= 2.6 defaults to weights_only=True here, which
    # is fine for a plain state dict but worth confirming for this file.
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    # FIX: put the assembled model in inference mode explicitly — the
    # freshly-constructed wrapper defaults to training mode.
    model.eval()
    return device, model, tok
|
|
|
|
|
|
|
|
# Load (or fetch from the resource cache) the models once at script start.
# A failure here is fatal for the app, so surface it and halt rendering.
try:
    device, model, tokenizer = load_models()
except Exception as e:
    st.error(f"π¨ Critical Error: Failed to load models. {str(e)}")
    st.info("Please try refreshing the page or contact support@example.com")
    st.stop()
|
|
|
|
|
# Preprocessing for the ViT encoder: 224x224 RGB, scaled to [0, 1] by
# ToTensor, then normalized to [-1, 1].
# NOTE(review): scalar mean/std rely on torchvision broadcasting them over
# all 3 channels (equivalent to mean=[0.5]*3) — confirm for the pinned
# torchvision version.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5),
])
|
|
|
|
|
|
|
|
# Page chrome configuration.
# NOTE(review): Streamlit requires set_page_config to be the FIRST st.* call
# in the script, but load_models() above can emit st.error/st.warning on its
# fallback paths, which would make this raise StreamlitAPIException —
# consider moving this call above the model-loading block.
st.set_page_config(
    page_title="Radiology Report Analysis",
    layout="wide",
    initial_sidebar_state="collapsed"
)
|
|
|
|
|
|
|
|
# Hide Streamlit's default chrome (header, menu, footer, deploy button) and
# set a light page background via injected raw CSS.
st.markdown("""
<style>
.reportview-container .main .block-container {padding-top: 2rem;}
header {visibility: hidden;}
.stDeployButton {display:none;}
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stApp {background-color: #f8f9fa;}
</style>
""", unsafe_allow_html=True)

# Page title and subtitle rendered as HTML for centering and color.
st.markdown("<h1 style='text-align:center; color:#2c3e50;'>π©Ί Radiology Report Analysis</h1>", unsafe_allow_html=True)
st.markdown("<p style='text-align:center; color:#7f8c8d;'>Upload a chest X-ray and click Generate Report</p>", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
# Upload step: shown only until a valid image is stashed in session_state.
if "img" not in st.session_state:
    uploaded = st.file_uploader("π€ Upload X-ray (PNG/JPG)", type=["png", "jpg", "jpeg"])
    if uploaded:
        try:
            # Decode once purely to validate the upload really is an image;
            # the decoded copy is discarded and the raw file kept instead.
            Image.open(uploaded).convert("RGB")
            # BUG FIX: validation leaves the UploadedFile's pointer at EOF,
            # so reopening it after the rerun would fail with "cannot
            # identify image file". Rewind before storing it.
            uploaded.seek(0)
            st.session_state.img = uploaded
            # NOTE(review): st.experimental_rerun() was removed in newer
            # Streamlit releases in favour of st.rerun(); confirm the
            # pinned version still provides it.
            st.experimental_rerun()
        except Exception as e:
            st.error(f"β Invalid image file: {str(e)}")
            st.stop()
    else:
        st.stop()


# Display step: reopen the stored upload on every rerun.
img_file = st.session_state.img
img_file.seek(0)  # defensive: a previous read may have consumed the buffer
img = Image.open(img_file).convert("RGB")
st.image(img, use_column_width=True, caption="Uploaded X-ray")
|
|
|
|
|
col1, col2 = st.columns(2)
with col1:
    # Run the model and show the generated report text.
    if st.button("βΆοΈ Generate Report", use_container_width=True, type="primary", key="generate"):
        with st.spinner("Analyzing X-ray. This may take 10-20 seconds..."):
            try:
                px = transform(img).unsqueeze(0).to(device)
                # FIX: inference only — disable autograd so no computation
                # graph is built (saves memory and time).
                with torch.no_grad():
                    out_ids = model.generate(px, max_length=128)
                report = tokenizer.decode(out_ids[0], skip_special_tokens=True)

                st.subheader("π AI-Generated Report")
                st.success(report)
            except Exception as e:
                st.error(f"β Analysis failed: {str(e)}")
                st.info("Please try with a different image or try again later")
|
|
|
|
|
with col2:
    # BUG FIX: the button label literal was split across two physical lines
    # (a raw newline inside a plain string is a SyntaxError); rejoined here.
    if st.button("β¬οΈ Upload Another", use_container_width=True, key="upload_another"):
        # Drop the stored upload and rerun to show the uploader again.
        del st.session_state.img
        st.experimental_rerun()
|
|
|
|
|
|
|
|
# Static footer with usage notes.
st.markdown("---")
st.markdown("""
**Note:**
- First-time model loading may take 1-2 minutes
- For optimal results, use clear chest X-ray images
- Contact support@example.com for assistance
""")