RakeshNJ12345 committed on
Commit
bce1c23
·
verified ·
1 Parent(s): ffe9e42

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +159 -103
src/streamlit_app.py CHANGED
@@ -1,124 +1,180 @@
1
- # app.py or streamlit_app.py
2
 
 
3
  import os
4
- import streamlit as st
5
- from PIL import Image
6
- import torch
7
- import torchvision.transforms as T
8
- import pydicom
9
- import numpy as np
10
- from transformers import ViTFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel
11
-
12
- # ─── FORCE ALL CACHE & CONFIG INTO /tmp ────────────────────────────────────────
13
- # must come before streamlit or transformers imports write any files
14
- for ENV, VAL in [
15
- ("HOME", "/tmp"),
16
- ("XDG_CONFIG_HOME", "/tmp"),
17
- ("STREAMLIT_HOME", "/tmp"),
18
- ("XDG_CACHE_HOME", "/tmp"),
19
- ("HF_HOME", "/tmp/hf"),
20
- ("TRANSFORMERS_CACHE","/tmp/hf/transformers"),
 
 
 
 
 
 
 
21
  ]:
22
- os.environ[ENV] = VAL
23
 
24
- os.makedirs("/tmp/streamlit", exist_ok=True)
25
- os.makedirs("/tmp/hf/transformers", exist_ok=True)
 
 
 
 
 
 
 
26
 
27
- # ─── YOUR MODEL ID ─────────────────────────────────────────────────────────────
28
  MODEL_ID = "RakeshNJ12345/Chest-Radiology"
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  @st.cache_resource(show_spinner=False)
31
- def load_model():
32
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
- fe = ViTFeatureExtractor.from_pretrained(MODEL_ID)
34
- tok = AutoTokenizer.from_pretrained(MODEL_ID)
35
- pipe = VisionEncoderDecoderModel.from_pretrained(MODEL_ID).to(device)
36
- return device, fe, tok, pipe
 
 
 
 
37
 
38
- device, feat_ext, tokenizer, model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- # ─── IMAGE PREPROCESSING ───────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  transform = T.Compose([
42
  T.Resize((224, 224)),
43
  T.ToTensor(),
44
  T.Normalize(mean=0.5, std=0.5),
45
  ])
46
 
47
- def load_image(uploaded_file):
48
- """Handle .dcm or normal images uniformly, returns a PIL RGB image."""
49
- name = uploaded_file.name.lower()
50
- if name.endswith(".dcm"):
51
- ds = pydicom.dcmread(uploaded_file)
52
- arr = ds.pixel_array.astype(np.float32)
53
- # normalize to 0–255
54
- arr = (arr - arr.min()) / (arr.max() - arr.min()) * 255.0
55
- arr = arr.astype(np.uint8)
56
- # if monochrome, convert to RGB by stacking
57
- if arr.ndim == 2:
58
- arr = np.stack([arr]*3, axis=-1)
59
- return Image.fromarray(arr)
60
- else:
61
- return Image.open(uploaded_file).convert("RGB")
62
-
63
- # ─── STREAMLIT UI ───────────────────────────────────────────────────────────────
64
  st.set_page_config(page_title="Radiology Report Analysis", layout="wide")
65
  st.markdown("<h1 style='text-align:center;'>🩺 Radiology Report Analysis</h1>", unsafe_allow_html=True)
66
- st.markdown("<p style='text-align:center;'>Upload a chest X-ray (PNG/JPG/JPEG/DCM) and click Generate Report.</p>",
67
- unsafe_allow_html=True)
68
 
69
- if "stage" not in st.session_state:
70
- st.session_state.stage = "upload"
71
-
72
- if st.session_state.stage == "upload":
73
- uploaded = st.file_uploader(
74
- "πŸ“€ Upload your chest X-ray",
75
- type=["png","jpg","jpeg","dcm"],
76
- label_visibility="visible"
77
- )
78
  if uploaded:
79
- st.image(load_image(uploaded), width=350,
80
- caption=f"{uploaded.name} β€” {uploaded.size/1e6:.2f} MB")
81
- if st.button("▢️ Generate Report"):
82
- st.session_state.uploaded = uploaded
83
- st.session_state.stage = "report"
84
- st.experimental_rerun()
85
-
86
- elif st.session_state.stage == "report":
87
- uploaded = st.session_state.uploaded
88
- img = load_image(uploaded)
89
-
90
- with st.spinner("πŸ” Analyzing…"):
91
- # 1) feature extraction
92
- pixel_values = feat_ext(images=img, return_tensors="pt").pixel_values.to(device)
93
- # 2) generation
94
- output_ids = model.generate(
95
- pixel_values,
96
- max_length=64,
97
- num_beams=4,
98
- no_repeat_ngram_size=2,
99
- early_stopping=True,
100
- )
101
- report = tokenizer.decode(output_ids[0], skip_special_tokens=True)
102
-
103
- col1, col2 = st.columns(2)
104
- with col1:
105
- st.subheader("Your Uploaded X-ray")
106
- st.image(img, use_column_width=True)
107
- st.markdown(f"**File:** {uploaded.name} \n**Size:** {uploaded.size/1e6:.2f} MB")
108
- with col2:
109
- st.subheader("πŸ“ AI Diagnosis & Report")
110
- st.markdown(
111
- f"<div style='background:#e0f7fa;padding:12px;border-radius:6px;'>{report}</div>",
112
- unsafe_allow_html=True
113
- )
114
- if st.button("⬅️ Upload Another"):
115
- st.session_state.stage = "upload"
116
- del st.session_state.uploaded
117
- st.experimental_rerun()
118
-
119
- st.markdown("""
120
- <hr>
121
- <p style='text-align:center;color:gray;font-size:0.8em;'>
122
- Powered by your fine-tuned ViT→T5 pipeline on Hugging Face.
123
- </p>
124
- """, unsafe_allow_html=True)
 
1
+ # streamlit_app.py
2
 
3
+ # ──── SET ENVIRONMENT VARIABLES BEFORE ANY IMPORTS ──────────────────────────────
4
  import os
5
+ import tempfile
6
+
7
+ # Create a dedicated cache directory
8
+ CACHE_DIR = "/tmp/hf_cache"
9
+ os.makedirs(CACHE_DIR, exist_ok=True)
10
+
11
+ # Set all relevant environment variables
12
+ os.environ.update({
13
+ "HOME": "/tmp",
14
+ "XDG_CONFIG_HOME": "/tmp",
15
+ "STREAMLIT_HOME": "/tmp/streamlit",
16
+ "XDG_CACHE_HOME": CACHE_DIR,
17
+ "HF_HOME": f"{CACHE_DIR}/huggingface",
18
+ "TRANSFORMERS_CACHE": f"{CACHE_DIR}/transformers",
19
+ "HF_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
20
+ "HUGGINGFACE_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub"
21
+ })
22
+
23
+ # Create all cache directories explicitly
24
+ for path in [
25
+ "/tmp/streamlit",
26
+ f"{CACHE_DIR}/huggingface",
27
+ f"{CACHE_DIR}/transformers",
28
+ f"{CACHE_DIR}/huggingface_hub"
29
  ]:
30
+ os.makedirs(path, exist_ok=True)
31
 
32
+ # ──── NOW IMPORT OTHER LIBRARIES ───────────────────────────────────────────────
33
+ import json
34
+ import torch
35
+ import torch.nn as nn
36
+ import torchvision.transforms as T
37
+ import streamlit as st
38
+ from PIL import Image
39
+ from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
40
+ from huggingface_hub import hf_hub_download
41
 
42
+ # ──── MODEL DEFINITION ─────────────────────────────────────────────────────────
43
  MODEL_ID = "RakeshNJ12345/Chest-Radiology"
44
 
45
+ class TwoViewVisionReportModel(nn.Module):
46
+ def __init__(self, vit: ViTModel, t5: T5ForConditionalGeneration, tokenizer: T5Tokenizer):
47
+ super().__init__()
48
+ self.vit = vit
49
+ self.proj_f = nn.Linear(vit.config.hidden_size, t5.config.d_model)
50
+ self.proj_l = nn.Linear(vit.config.hidden_size, t5.config.d_model)
51
+ self.tokenizer = tokenizer
52
+ self.t5 = t5
53
+
54
+ def generate(self, img: torch.Tensor, max_length: int = 64) -> torch.Tensor:
55
+ device = img.device
56
+ vf = self.vit(pixel_values=img).pooler_output
57
+ pf = self.proj_f(vf).unsqueeze(1)
58
+ prefix = pf # single-view only
59
+
60
+ enc = self.tokenizer("report:", return_tensors="pt").to(device)
61
+ txt_emb = self.t5.encoder.embed_tokens(enc.input_ids)
62
+
63
+ enc_emb = torch.cat([prefix, txt_emb], dim=1)
64
+ enc_mask = torch.cat([
65
+ torch.ones(1, 1, device=device, dtype=torch.long),
66
+ enc.attention_mask
67
+ ], dim=1)
68
+
69
+ enc_out = self.t5.encoder(
70
+ inputs_embeds=enc_emb,
71
+ attention_mask=enc_mask
72
+ )
73
+
74
+ out_ids = self.t5.generate(
75
+ encoder_outputs=enc_out,
76
+ encoder_attention_mask=enc_mask,
77
+ max_length=max_length,
78
+ num_beams=1,
79
+ do_sample=False,
80
+ eos_token_id=self.tokenizer.eos_token_id,
81
+ )
82
+ return out_ids
83
+
84
+ # ──── MODEL LOADING WITH ROBUST CACHE HANDLING ─────────────────────────────────
85
  @st.cache_resource(show_spinner=False)
86
+ def load_models():
87
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
+
89
+ # Ensure cache directories exist
90
+ for path in [
91
+ f"{CACHE_DIR}/huggingface",
92
+ f"{CACHE_DIR}/transformers",
93
+ f"{CACHE_DIR}/huggingface_hub"
94
+ ]:
95
+ os.makedirs(path, exist_ok=True)
96
 
97
+ # Download config with explicit cache
98
+ cfg_path = hf_hub_download(
99
+ repo_id=MODEL_ID,
100
+ filename="config.json",
101
+ repo_type="model",
102
+ cache_dir=f"{CACHE_DIR}/huggingface_hub",
103
+ local_files_only=False,
104
+ force_download=True
105
+ )
106
+ cfg = json.load(open(cfg_path, "r"))
107
+
108
+ # Load models with explicit cache directories
109
+ vit = ViTModel.from_pretrained(
110
+ "google/vit-base-patch16-224",
111
+ ignore_mismatched_sizes=True,
112
+ cache_dir=f"{CACHE_DIR}/transformers"
113
+ ).to(device)
114
+
115
+ t5 = T5ForConditionalGeneration.from_pretrained(
116
+ "t5-base",
117
+ cache_dir=f"{CACHE_DIR}/transformers"
118
+ ).to(device)
119
+
120
+ tok = T5Tokenizer.from_pretrained(
121
+ MODEL_ID,
122
+ cache_dir=f"{CACHE_DIR}/transformers"
123
+ )
124
 
125
+ # Load combined model
126
+ model = TwoViewVisionReportModel(vit, t5, tok).to(device)
127
+
128
+ ckpt_path = hf_hub_download(
129
+ repo_id=MODEL_ID,
130
+ filename="pytorch_model.bin",
131
+ repo_type="model",
132
+ cache_dir=f"{CACHE_DIR}/huggingface_hub",
133
+ local_files_only=False,
134
+ force_download=True
135
+ )
136
+
137
+ state = torch.load(ckpt_path, map_location=device)
138
+ model.load_state_dict(state)
139
+ return device, model, tok
140
+
141
+ # ──── APP INTERFACE ───────────────────────────────────────────────────────────
142
+ device, model, tokenizer = load_models()
143
  transform = T.Compose([
144
  T.Resize((224, 224)),
145
  T.ToTensor(),
146
  T.Normalize(mean=0.5, std=0.5),
147
  ])
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  st.set_page_config(page_title="Radiology Report Analysis", layout="wide")
150
  st.markdown("<h1 style='text-align:center;'>🩺 Radiology Report Analysis</h1>", unsafe_allow_html=True)
151
+ st.markdown("<p style='text-align:center;'>Upload a chest X-ray and click Generate Report.</p>", unsafe_allow_html=True)
 
152
 
153
+ # File upload handling
154
+ if "img" not in st.session_state:
155
+ uploaded = st.file_uploader("πŸ“€ Upload X-ray (PNG/JPG)", type=["png", "jpg", "jpeg"])
 
 
 
 
 
 
156
  if uploaded:
157
+ st.session_state.img = uploaded
158
+ st.experimental_rerun()
159
+ else:
160
+ st.stop()
161
+
162
+ img_file = st.session_state.img
163
+ img = Image.open(img_file).convert("RGB")
164
+ st.image(img, use_column_width=True)
165
+
166
+ col1, col2 = st.columns(2)
167
+ with col1:
168
+ if st.button("▢️ Generate Report", use_container_width=True):
169
+ with st.spinner("Analyzing X-ray..."):
170
+ px = transform(img).unsqueeze(0).to(device)
171
+ out_ids = model.generate(px, max_length=128)
172
+ report = tokenizer.decode(out_ids[0], skip_special_tokens=True)
173
+
174
+ st.subheader("πŸ“ AI-Generated Report")
175
+ st.success(report)
176
+
177
+ with col2:
178
+ if st.button("⬅️ Upload Another", use_container_width=True):
179
+ del st.session_state.img
180
+ st.experimental_rerun()