RakeshNJ12345 committed on
Commit
e7abd9d
·
verified ·
1 Parent(s): 50df4a4

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +124 -43
src/streamlit_app.py CHANGED
@@ -3,6 +3,7 @@
3
  # ──── SET ENVIRONMENT VARIABLES BEFORE ANY IMPORTS ──────────────────────────────
4
  import os
5
  import tempfile
 
6
 
7
  # Create dedicated cache directories
8
  CACHE_DIR = "/tmp/hf_cache"
@@ -14,12 +15,13 @@ os.makedirs(STREAMLIT_DIR, exist_ok=True)
14
  os.environ.update({
15
  "HOME": "/tmp",
16
  "XDG_CONFIG_HOME": "/tmp",
17
- "STREAMLIT_HOME": STREAMLIT_DIR, # Force Streamlit to use our temp directory
18
  "XDG_CACHE_HOME": CACHE_DIR,
19
  "HF_HOME": f"{CACHE_DIR}/huggingface",
20
  "TRANSFORMERS_CACHE": f"{CACHE_DIR}/transformers",
21
  "HF_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
22
- "HUGGINGFACE_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub"
 
23
  })
24
 
25
  # Create all cache directories explicitly
@@ -46,11 +48,14 @@ import torchvision.transforms as T
46
  import streamlit as st
47
  from PIL import Image
48
  from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
49
- from huggingface_hub import hf_hub_download
50
 
51
  # ──── MODEL DEFINITION ─────────────────────────────────────────────────────────
52
  MODEL_ID = "RakeshNJ12345/Chest-Radiology"
53
 
 
 
 
54
  class TwoViewVisionReportModel(nn.Module):
55
  def __init__(self, vit: ViTModel, t5: T5ForConditionalGeneration, tokenizer: T5Tokenizer):
56
  super().__init__()
@@ -90,7 +95,7 @@ class TwoViewVisionReportModel(nn.Module):
90
  )
91
  return out_ids
92
 
93
- # ──── MODEL LOADING WITH ROBUST CACHE HANDLING ─────────────────────────────────
94
  @st.cache_resource(show_spinner=False)
95
  def load_models():
96
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -103,50 +108,106 @@ def load_models():
103
  ]:
104
  os.makedirs(path, exist_ok=True)
105
 
106
- # Download config with explicit cache
107
- cfg_path = hf_hub_download(
108
- repo_id=MODEL_ID,
109
- filename="config.json",
110
- repo_type="model",
111
- cache_dir=f"{CACHE_DIR}/huggingface_hub",
112
- local_files_only=False
113
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  cfg = json.load(open(cfg_path, "r"))
115
 
116
  # Load models with explicit cache directories
117
- vit = ViTModel.from_pretrained(
118
- "google/vit-base-patch16-224",
119
- ignore_mismatched_sizes=True,
120
- cache_dir=f"{CACHE_DIR}/transformers"
121
- ).to(device)
 
 
 
 
 
 
 
 
 
122
 
123
- t5 = T5ForConditionalGeneration.from_pretrained(
124
- "t5-base",
125
- cache_dir=f"{CACHE_DIR}/transformers"
126
- ).to(device)
 
 
 
 
 
 
 
127
 
128
- tok = T5Tokenizer.from_pretrained(
129
- MODEL_ID,
130
- cache_dir=f"{CACHE_DIR}/transformers"
131
- )
 
 
 
 
 
 
 
132
 
133
  # Load combined model
134
  model = TwoViewVisionReportModel(vit, t5, tok).to(device)
135
 
136
- ckpt_path = hf_hub_download(
137
- repo_id=MODEL_ID,
138
- filename="pytorch_model.bin",
139
- repo_type="model",
140
- cache_dir=f"{CACHE_DIR}/huggingface_hub",
141
- local_files_only=False
142
- )
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  state = torch.load(ckpt_path, map_location=device)
145
  model.load_state_dict(state)
146
  return device, model, tok
147
 
148
- # ──── APP INTERFACE ───────────────────────────────────────────────────────────
149
- device, model, tokenizer = load_models()
 
 
 
 
 
150
  transform = T.Compose([
151
  T.Resize((224, 224)),
152
  T.ToTensor(),
@@ -161,8 +222,15 @@ st.markdown("<p style='text-align:center;'>Upload a chest X-ray and click Genera
161
  if "img" not in st.session_state:
162
  uploaded = st.file_uploader("πŸ“€ Upload X-ray (PNG/JPG)", type=["png", "jpg", "jpeg"])
163
  if uploaded:
164
- st.session_state.img = uploaded
165
- st.experimental_rerun()
 
 
 
 
 
 
 
166
  else:
167
  st.stop()
168
 
@@ -174,14 +242,27 @@ col1, col2 = st.columns(2)
174
  with col1:
175
  if st.button("▢️ Generate Report", use_container_width=True):
176
  with st.spinner("Analyzing X-ray..."):
177
- px = transform(img).unsqueeze(0).to(device)
178
- out_ids = model.generate(px, max_length=128)
179
- report = tokenizer.decode(out_ids[0], skip_special_tokens=True)
180
-
181
- st.subheader("πŸ“ AI-Generated Report")
182
- st.success(report)
 
 
 
 
183
 
184
  with col2:
185
  if st.button("⬅️ Upload Another", use_container_width=True):
186
  del st.session_state.img
187
- st.experimental_rerun()
 
 
 
 
 
 
 
 
 
 
3
  # ──── SET ENVIRONMENT VARIABLES BEFORE ANY IMPORTS ──────────────────────────────
4
  import os
5
  import tempfile
6
+ import requests
7
 
8
  # Create dedicated cache directories
9
  CACHE_DIR = "/tmp/hf_cache"
 
15
  os.environ.update({
16
  "HOME": "/tmp",
17
  "XDG_CONFIG_HOME": "/tmp",
18
+ "STREAMLIT_HOME": STREAMLIT_DIR,
19
  "XDG_CACHE_HOME": CACHE_DIR,
20
  "HF_HOME": f"{CACHE_DIR}/huggingface",
21
  "TRANSFORMERS_CACHE": f"{CACHE_DIR}/transformers",
22
  "HF_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
23
+ "HUGGINGFACE_HUB_CACHE": f"{CACHE_DIR}/huggingface_hub",
24
+ "HF_HUB_DISABLE_TELEMETRY": "1" # Disable telemetry to reduce rate limiting
25
  })
26
 
27
  # Create all cache directories explicitly
 
48
  import streamlit as st
49
  from PIL import Image
50
  from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
51
+ from huggingface_hub import hf_hub_download, HfApi
52
 
53
  # ──── MODEL DEFINITION ─────────────────────────────────────────────────────────
54
  MODEL_ID = "RakeshNJ12345/Chest-Radiology"
55
 
56
+ # Alternative model access through proxy
57
+ PROXY_URL = "https://hf-mirror.com"
58
+
59
  class TwoViewVisionReportModel(nn.Module):
60
  def __init__(self, vit: ViTModel, t5: T5ForConditionalGeneration, tokenizer: T5Tokenizer):
61
  super().__init__()
 
95
  )
96
  return out_ids
97
 
98
# ──── MODEL LOADING WITH PROXY SUPPORT AND ERROR HANDLING ──────────────────────
@st.cache_resource(show_spinner=False)
def load_models():
    """Download and assemble the two-view radiology report model.

    Returns:
        (device, model, tokenizer): the torch device in use, the combined
        TwoViewVisionReportModel moved to that device, and its T5 tokenizer.

    Network failures against the default Hugging Face endpoint fall back to
    the PROXY_URL mirror (hub files) or a single retry (from_pretrained).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Ensure every cache directory exists before any download starts.
    # NOTE(review): the original directory list is elided in the diff view —
    # confirm against the full file; these match the caches configured above.
    for path in [
        CACHE_DIR,
        f"{CACHE_DIR}/huggingface",
        f"{CACHE_DIR}/transformers",
        f"{CACHE_DIR}/huggingface_hub",
    ]:
        os.makedirs(path, exist_ok=True)

    def _hub_download(filename):
        """Fetch one file from MODEL_ID, retrying against the mirror on failure."""
        try:
            return hf_hub_download(
                repo_id=MODEL_ID,
                filename=filename,
                repo_type="model",
                cache_dir=f"{CACHE_DIR}/huggingface_hub",
                local_files_only=False,
            )
        except Exception as e:  # network errors, rate limiting, 5xx
            st.warning(f"⚠️ Hub download of {filename} failed ({e}); trying mirror...")
            # HfApi has no hf_hub_download() method, so the original fallback
            # raised AttributeError and guessed a wrong local path. Pass the
            # mirror as `endpoint` of the module-level hf_hub_download and
            # capture the path it actually returns.
            return hf_hub_download(
                repo_id=MODEL_ID,
                filename=filename,
                repo_type="model",
                cache_dir=f"{CACHE_DIR}/huggingface_hub",
                local_files_only=False,
                endpoint=PROXY_URL,
            )

    def _retry(loader, **kwargs):
        """Call a from_pretrained-style loader, retrying once on failure.

        transformers removed the `mirror=` kwarg the original fallbacks used
        (they would raise TypeError); a forced re-download covers the common
        transient-failure / stale-partial-cache cases instead.
        """
        try:
            return loader(**kwargs)
        except Exception as e:
            st.warning(f"⚠️ Download failed ({e}); retrying once...")
            return loader(**kwargs, force_download=True)

    # Config is fetched for parity with the original code; it is not used
    # further in this function.
    cfg_path = _hub_download("config.json")
    with open(cfg_path, "r") as f:  # context manager: original leaked the handle
        cfg = json.load(f)

    # Load backbone models with explicit cache directories.
    vit = _retry(
        ViTModel.from_pretrained,
        pretrained_model_name_or_path="google/vit-base-patch16-224",
        ignore_mismatched_sizes=True,
        cache_dir=f"{CACHE_DIR}/transformers",
    ).to(device)

    t5 = _retry(
        T5ForConditionalGeneration.from_pretrained,
        pretrained_model_name_or_path="t5-base",
        cache_dir=f"{CACHE_DIR}/transformers",
    ).to(device)

    tok = _retry(
        T5Tokenizer.from_pretrained,
        pretrained_model_name_or_path=MODEL_ID,
        cache_dir=f"{CACHE_DIR}/transformers",
    )

    # Assemble the combined model and restore the fine-tuned weights.
    model = TwoViewVisionReportModel(vit, t5, tok).to(device)
    ckpt_path = _hub_download("pytorch_model.bin")
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    return device, model, tok
203
 
204
# ──── APP INTERFACE WITH ERROR HANDLING ───────────────────────────────────────
# load_models() is wrapped in @st.cache_resource, so the heavy download and
# weight-loading work runs once per process; later reruns reuse the result.
try:
    device, model, tokenizer = load_models()
except Exception as e:
    # Model loading is the one unrecoverable failure for this app:
    # surface it to the user and halt the script run entirely.
    st.error(f"🚨 Critical Error: Failed to load models. {str(e)}")
    st.stop()
210
+
211
  transform = T.Compose([
212
  T.Resize((224, 224)),
213
  T.ToTensor(),
 
222
# Gate: until an image is in session state, show only the uploader and
# halt the run; a valid upload stores the file and triggers a rerun.
if "img" not in st.session_state:
    uploaded = st.file_uploader("📤 Upload X-ray (PNG/JPG)", type=["png", "jpg", "jpeg"])
    if uploaded:
        try:
            # Validate the upload. Pillow requires verify() on a freshly
            # opened image *before* any pixel access — the original called
            # convert("RGB") first, which loads the data and makes the
            # subsequent verify() meaningless.
            Image.open(uploaded).verify()
            uploaded.seek(0)  # verify() consumes the stream; rewind for later reads
            st.session_state.img = uploaded
            st.experimental_rerun()
        except Exception as e:
            st.error(f"❌ Invalid image file: {str(e)}")
            st.stop()
    else:
        st.stop()
236
 
 
242
with col1:
    if st.button("▶️ Generate Report", use_container_width=True):
        with st.spinner("Analyzing X-ray..."):
            try:
                # Pure inference — disable autograd so no gradient state is
                # built up (the original tracked gradients needlessly).
                with torch.no_grad():
                    px = transform(img).unsqueeze(0).to(device)
                    out_ids = model.generate(px, max_length=128)
                report = tokenizer.decode(out_ids[0], skip_special_tokens=True)

                st.subheader("📝 AI-Generated Report")
                st.success(report)
            except Exception as e:
                # Keep the app alive on bad inputs / OOM; report and advise.
                st.error(f"❌ Analysis failed: {str(e)}")
                st.info("Please try with a different image or try again later")
255
 
256
with col2:
    if st.button("⬅️ Upload Another", use_container_width=True):
        # pop() with a default is safe even if the key already vanished in a
        # concurrent rerun, where the original `del` would raise KeyError.
        st.session_state.pop("img", None)
        st.experimental_rerun()

# Add footer with troubleshooting
st.markdown("---")
st.markdown("""
**Troubleshooting Tips:**
- If model download fails, wait 5 minutes and refresh
- Use standard chest X-ray images in PNG or JPG format
- For persistent errors, contact support@example.com
""")