# streamlit_app.py
# ──── SET ENVIRONMENT VARIABLES BEFORE ANY IMPORTS ──────────────────────────────
import os
import tempfile
# Create dedicated cache directories
CACHE_DIR = "/tmp/hf_cache"
STREAMLIT_DIR = "/tmp/streamlit"
# Create directories safely without recursion
def safe_makedirs(path):
try:
os.makedirs(path, exist_ok=True)
except Exception as e:
print(f"Warning: Could not create {path}: {str(e)}")
safe_makedirs(CACHE_DIR)
safe_makedirs(STREAMLIT_DIR)
# Set all relevant environment variables
os.environ.update({
"HOME": "/tmp",
"XDG_CONFIG_HOME": "/tmp",
"STREAMLIT_HOME": STREAMLIT_DIR,
"XDG_CACHE_HOME": CACHE_DIR,
"HF_HOME": f"{CACHE_DIR}/huggingface",
"TRANSFORMERS_CACHE": f"{CACHE_DIR}/transformers",
"HF_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
"HUGGINGFACE_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
"HF_HUB_DISABLE_TELEMETRY": "1",
"STREAMLIT_SERVER_ENABLE_FILE_WATCHER": "false"
})
# Create all cache subdirectories
for path in [
f"{CACHE_DIR}/huggingface",
f"{CACHE_DIR}/transformers",
f"{CACHE_DIR}/huggingface_hub",
f"{STREAMLIT_DIR}/config"
]:
safe_makedirs(path)
# Create Streamlit config to disable usage stats
CONFIG_TOML = f"{STREAMLIT_DIR}/config/config.toml"
if not os.path.exists(CONFIG_TOML):
try:
with open(CONFIG_TOML, "w") as f:
f.write("[browser]\n")
f.write("gatherUsageStats = false\n")
f.write("[server]\n")
f.write("fileWatcherType = none\n")
except Exception as e:
print(f"Warning: Could not create Streamlit config: {str(e)}")
# ──── NOW IMPORT OTHER LIBRARIES ───────────────────────────────────────────────
import json
import torch
import torch.nn as nn
import torchvision.transforms as T
import streamlit as st
from PIL import Image
from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
from huggingface_hub import hf_hub_download, HfApi
# ──── MODEL DEFINITION ─────────────────────────────────────────────────────────
# Hub repo from which config.json, the tokenizer and pytorch_model.bin are fetched.
MODEL_ID: str = "RakeshNJ12345/Chest-Radiology"
# Alternative Hub endpoint (a mirror) used when a default download fails.
PROXY_URL: str = "https://hf-mirror.com"
class TwoViewVisionReportModel(nn.Module):
    """Image-to-text model: a ViT image encoder feeding a T5 encoder-decoder.

    The module owns two projection heads, ``proj_f`` and ``proj_l`` (presumably
    frontal and lateral views — confirm against the training code), but
    :meth:`generate` is single-view and only uses ``proj_f``. ``proj_l`` must
    still exist so the saved state dict loads without missing keys.

    Args:
        vit: Vision Transformer whose ``pooler_output`` embeds the image.
        t5: T5 model that encodes image prefix + prompt and decodes the text.
        tokenizer: T5 tokenizer used to embed the ``"report:"`` prompt.
    """

    def __init__(self, vit: ViTModel, t5: T5ForConditionalGeneration, tokenizer: T5Tokenizer):
        super().__init__()
        self.vit = vit
        # Map the ViT hidden size onto T5's embedding width.
        self.proj_f = nn.Linear(vit.config.hidden_size, t5.config.d_model)
        self.proj_l = nn.Linear(vit.config.hidden_size, t5.config.d_model)
        self.tokenizer = tokenizer
        self.t5 = t5

    @torch.no_grad()
    def generate(self, img: torch.Tensor, max_length: int = 128) -> torch.Tensor:
        """Greedily decode a report for each image in a batch.

        The pooled ViT feature of each image is projected to a single "soft
        token" and prepended to the embedded ``"report:"`` prompt; T5 encodes
        that sequence and then greedily decodes the report. Runs without
        autograd since this is inference only.

        Args:
            img: Pixel tensor passed to the ViT as ``pixel_values``; dim 0 is
                the batch (the app passes a single image, batch size 1).
            max_length: Maximum length of the generated token sequence.

        Returns:
            Generated token ids, one row per image, as returned by
            ``T5ForConditionalGeneration.generate``.
        """
        device = img.device
        batch = img.size(0)

        # (B, hidden) -> (B, 1, d_model): one prefix position per image.
        prefix = self.proj_f(self.vit(pixel_values=img).pooler_output).unsqueeze(1)

        # The prompt is identical for every image: tokenize it once and
        # broadcast it across the batch so batches of any size concatenate.
        enc = self.tokenizer("report:", return_tensors="pt").to(device)
        txt_emb = self.t5.encoder.embed_tokens(enc.input_ids).expand(batch, -1, -1)
        txt_mask = enc.attention_mask.expand(batch, -1)

        enc_emb = torch.cat([prefix, txt_emb], dim=1)
        enc_mask = torch.cat(
            [torch.ones(batch, 1, device=device, dtype=torch.long), txt_mask],
            dim=1,
        )
        enc_out = self.t5.encoder(inputs_embeds=enc_emb, attention_mask=enc_mask)

        # T5's generate()/forward() take the encoder padding mask as
        # `attention_mask`; that is what the decoder's cross-attention uses.
        return self.t5.generate(
            encoder_outputs=enc_out,
            attention_mask=enc_mask,
            max_length=max_length,
            num_beams=1,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
        )
# ──── MODEL LOADING WITH ERROR HANDLING ────────────────────────────────────────
# Files needed to rebuild a transformers model/tokenizer from a snapshot that
# was downloaded through PROXY_URL.
# NOTE(review): assumes the weights are published as safetensors in every repo
# loaded below — TODO confirm, or add "*.bin".
_SNAPSHOT_PATTERNS = ["*.json", "*.model", "*.safetensors", "*.txt"]


def _download_with_fallback(filename, on_failure):
    """Download *filename* from MODEL_ID, retrying through PROXY_URL.

    Args:
        filename: File name inside the MODEL_ID repo.
        on_failure: Called with the exception from the first attempt so the
            caller can report it in the UI before the retry.

    Returns:
        Local path of the downloaded file.

    Raises:
        Whatever the PROXY_URL retry raises if it also fails.
    """
    kwargs = dict(
        repo_id=MODEL_ID,
        filename=filename,
        repo_type="model",
        cache_dir=f"{CACHE_DIR}/huggingface_hub",
        local_files_only=False,
    )
    try:
        return hf_hub_download(**kwargs)
    except Exception as e:
        on_failure(e)
        return HfApi(endpoint=PROXY_URL).hf_hub_download(**kwargs)


def _load_pretrained(loader, repo_id, what, **kwargs):
    """Call a ``from_pretrained`` *loader* for *repo_id*, falling back to PROXY_URL.

    transformers' ``from_pretrained`` has no per-call endpoint override (its
    deprecated ``mirror=`` argument does not redirect downloads), so the
    fallback downloads a snapshot through ``HfApi(endpoint=PROXY_URL)`` and
    loads from that local directory instead.

    Args:
        loader: A ``from_pretrained`` classmethod.
        repo_id: Hub repo to load.
        what: Short label used in the warning message.
        **kwargs: Extra arguments forwarded to *loader* on both attempts.
    """
    try:
        return loader(repo_id, cache_dir=f"{CACHE_DIR}/transformers", **kwargs)
    except Exception as e:
        st.warning(f"⚠️ Standard {what} download failed: {e}")
        local_dir = HfApi(endpoint=PROXY_URL).snapshot_download(
            repo_id=repo_id,
            cache_dir=f"{CACHE_DIR}/huggingface_hub",
            allow_patterns=_SNAPSHOT_PATTERNS,
        )
        return loader(local_dir, **kwargs)


@st.cache_resource(show_spinner=False)
def load_models():
    """Download (if needed) and load the report-generation model.

    Loads a ViT (``google/vit-base-patch16-224``) and T5 (``t5-base``), the
    tokenizer and fine-tuned weights from ``MODEL_ID``, and wraps them in
    :class:`TwoViewVisionReportModel`. Each download is tried against the
    default Hub endpoint first and retried through ``PROXY_URL`` on failure.
    ``st.cache_resource`` keeps the result, so this runs once per process.

    Returns:
        tuple: ``(device, model, tokenizer)`` — the torch device, the model in
        eval mode on that device, and the T5 tokenizer.

    Raises:
        Whatever a fallback download/load raises when both attempts fail; the
        caller reports it.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Ensure cache directories exist
    for path in (
        f"{CACHE_DIR}/huggingface",
        f"{CACHE_DIR}/transformers",
        f"{CACHE_DIR}/huggingface_hub",
    ):
        safe_makedirs(path)

    def _config_failed(e):
        st.error(f"❌ Failed to download model config: {e}")
        st.info("⚠️ Trying alternative download method...")

    cfg_path = _download_with_fallback("config.json", _config_failed)
    # The config values are not used below; parsing still surfaces a corrupt
    # download early, and `with` closes the file handle.
    with open(cfg_path, "r", encoding="utf-8") as f:
        json.load(f)

    vit = _load_pretrained(
        ViTModel.from_pretrained,
        "google/vit-base-patch16-224",
        "ViT",
        ignore_mismatched_sizes=True,
    ).to(device)
    t5 = _load_pretrained(T5ForConditionalGeneration.from_pretrained, "t5-base", "T5").to(device)
    tok = _load_pretrained(T5Tokenizer.from_pretrained, MODEL_ID, "tokenizer")

    # Load combined model
    model = TwoViewVisionReportModel(vit, t5, tok).to(device)
    ckpt_path = _download_with_fallback(
        "pytorch_model.bin",
        lambda e: st.warning(f"⚠️ Standard model weights download failed: {e}"),
    )
    # NOTE(security): torch.load unpickles the checkpoint, which can execute
    # code; acceptable only because MODEL_ID is trusted. If the file is a
    # plain state dict, consider weights_only=True (torch >= 1.13).
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    # The freshly built wrapper starts in training mode; switch to inference.
    model.eval()
    return device, model, tok
# ──── APP INTERFACE ───────────────────────────────────────────────────────────
# Streamlit configuration. st.set_page_config must run before any other
# Streamlit command (older releases raise StreamlitAPIException otherwise), and
# load_models() below can emit st.warning/st.error, so configure the page first.
st.set_page_config(
    page_title="Radiology Report Analysis",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# st.experimental_rerun was superseded by st.rerun and later removed; use
# whichever this Streamlit installation provides.
_rerun = getattr(st, "rerun", None) or st.experimental_rerun

try:
    device, model, tokenizer = load_models()
except Exception as e:
    st.error(f"🚨 Critical Error: Failed to load models. {e}")
    st.info("Please try refreshing the page or contact support@example.com")
    st.stop()

# Resize to 224x224 and scale pixels from [0, 1] to [-1, 1]
# (presumably matching the ViT checkpoint's preprocessing — confirm).
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5),
])

# Custom CSS
st.markdown("""
<style>
.reportview-container .main .block-container {padding-top: 2rem;}
header {visibility: hidden;}
.stDeployButton {display:none;}
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stApp {background-color: #f8f9fa;}
</style>
""", unsafe_allow_html=True)
st.markdown("<h1 style='text-align:center; color:#2c3e50;'>🩺 Radiology Report Analysis</h1>", unsafe_allow_html=True)
st.markdown("<p style='text-align:center; color:#7f8c8d;'>Upload a chest X-ray and click Generate Report</p>", unsafe_allow_html=True)

# File upload handling: keep the upload in session state so it survives reruns.
if "img" not in st.session_state:
    uploaded = st.file_uploader("📤 Upload X-ray (PNG/JPG)", type=["png", "jpg", "jpeg"])
    if not uploaded:
        st.stop()
    # Only the decode check is guarded; rerun/stop are control-flow
    # exceptions and must never be caught by the handler below.
    try:
        probe = Image.open(uploaded).convert("RGB")
        # Quick verification: thumbnail() forces a full decode.
        probe.thumbnail((10, 10))
    except Exception as e:
        st.error(f"❌ Invalid image file: {e}")
        st.stop()
    st.session_state.img = uploaded
    _rerun()

img_file = st.session_state.img
img = Image.open(img_file).convert("RGB")
st.image(img, use_column_width=True, caption="Uploaded X-ray")

col1, col2 = st.columns(2)
with col1:
    if st.button("▶️ Generate Report", use_container_width=True, type="primary", key="generate"):
        with st.spinner("Analyzing X-ray. This may take 10-20 seconds..."):
            try:
                px = transform(img).unsqueeze(0).to(device)
                out_ids = model.generate(px, max_length=128)
                report = tokenizer.decode(out_ids[0], skip_special_tokens=True)
            except Exception as e:
                st.error(f"❌ Analysis failed: {e}")
                st.info("Please try with a different image or try again later")
            else:
                st.subheader("📝 AI-Generated Report")
                st.success(report)
with col2:
    if st.button("⬅️ Upload Another", use_container_width=True, key="upload_another"):
        del st.session_state.img
        _rerun()

# Add footer
st.markdown("---")
st.markdown("""
**Note:**
- First-time model loading may take 1-2 minutes
- For optimal results, use clear chest X-ray images
- Contact support@example.com for assistance
""")