RakeshNJ12345 commited on
Commit
cf005fe
Β·
verified Β·
1 Parent(s): 0c40d86

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +160 -38
streamlit_app.py CHANGED
@@ -1,40 +1,162 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
4
  import streamlit as st
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms as T
6
  import streamlit as st
7
+ from PIL import Image
8
+ from transformers import (
9
+ ViTConfig, ViTModel,
10
+ T5ForConditionalGeneration,
11
+ T5Tokenizer,
12
+ )
13
 
14
+ # ─── FORCE ALL CACHE & CONFIG INTO /tmp ─────────────────────────────────────
15
+ for ENV, VAL in [
16
+ ("HOME", "/tmp"),
17
+ ("XDG_CONFIG_HOME", "/tmp"),
18
+ ("STREAMLIT_HOME", "/tmp"),
19
+ ("XDG_CACHE_HOME", "/tmp"),
20
+ ("HF_HOME", "/tmp/hf"),
21
+ ("TRANSFORMERS_CACHE", "/tmp/hf/transformers"),
22
+ ]:
23
+ os.environ[ENV] = VAL
24
+ os.makedirs("/tmp/streamlit", exist_ok=True)
25
+ os.makedirs("/tmp/hf/transformers", exist_ok=True)
26
+
27
+
28
+ # ─── YOUR HF MODEL REPO ───────────────────────────────────────────────────────
29
+ HF_MODEL_ID = "RakeshNJ12345/Chest-Radiology"
30
+
31
+
32
+ @st.cache_resource(show_spinner=False)
33
+ def load_models():
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+
36
+ # 1) VIT: load its config, build fresh, then we'll load YOUR weights into it
37
+ vit_cfg = ViTConfig.from_pretrained("google/vit-base-patch16-224")
38
+ vit = ViTModel(vit_cfg)
39
+
40
+ # 2) T5 + tokenizer: same idea, fresh + load YOUR weights
41
+ t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
42
+ tok = T5Tokenizer.from_pretrained(HF_MODEL_ID)
43
+
44
+ # 3) grab the single combined file from your repo
45
+ state = torch.hub.load_state_dict_from_url(
46
+ f"https://huggingface.co/{HF_MODEL_ID}/resolve/main/pytorch_model.bin",
47
+ map_location="cpu", check_hash=True
48
+ )
49
+
50
+ # 4) split into vit vs t5 state_dicts
51
+ vit_state = {k[len("vit."):]: v for k,v in state.items() if k.startswith("vit.")}
52
+ t5_state = {k[len("t5."):]: v for k,v in state.items() if k.startswith("t5.")}
53
+
54
+ # 5) load them
55
+ vit.load_state_dict(vit_state, strict=False)
56
+ t5.load_state_dict(t5_state, strict=False)
57
+
58
+ # 6) move to device & eval
59
+ vit.to(device).eval()
60
+ t5.to(device).eval()
61
+
62
+ return device, vit, t5, tok
63
+
64
+
65
+ device, vit, t5, tokenizer = load_models()
66
+
67
+
68
+ # ─── IMAGE PREPROCESSING ─────────────────────────────────────────────────────
69
+ transform = T.Compose([
70
+ T.Resize((224, 224)),
71
+ T.ToTensor(),
72
+ T.Normalize(mean=0.5, std=0.5),
73
+ ])
74
+
75
+
76
+ # ─── STREAMLIT LAYOUT ────────────────────────────────────────────────────────
77
+ st.set_page_config(page_title="Radiology Report Analysis", layout="wide")
78
+ st.markdown("<h1 style='text-align:center;'>🩺 Radiology Report Analysis</h1>",
79
+ unsafe_allow_html=True)
80
+ st.markdown(
81
+ "<p style='text-align:center;'>Upload a chest X-ray (PNG/JPG) to generate an AI report.</p>",
82
+ unsafe_allow_html=True
83
+ )
84
+
85
+ if "stage" not in st.session_state:
86
+ st.session_state.stage = "upload"
87
+
88
+
89
+ # ─── UPLOAD SCREEN ───────────────────────────────────────────────────────────
90
+ if st.session_state.stage == "upload":
91
+ up = st.file_uploader("", type=["png","jpg","jpeg"], label_visibility="collapsed")
92
+ if up:
93
+ st.image(up, width=350, caption=f"{up.name} β€” {up.size/1e6:.2f} MB")
94
+ if st.button("▢️ Generate Report"):
95
+ st.session_state.uploaded = up
96
+ st.session_state.stage = "report"
97
+ st.experimental_rerun()
98
+
99
+
100
+ # ─── REPORT SCREEN ───────────────────���───────────────────────────────────────
101
+ elif st.session_state.stage == "report":
102
+ img = Image.open(st.session_state.uploaded).convert("RGB")
103
+
104
+ with st.spinner("πŸ”Ž Analyzing…"):
105
+ # 1) ViT features
106
+ x = transform(img).unsqueeze(0).to(device)
107
+ vfeat = vit(pixel_values=x).pooler_output # [1,768]
108
+
109
+ # 2) project into T5’s hidden size
110
+ proj = nn.Linear(vfeat.size(-1), t5.config.d_model).to(device)
111
+ prefix = proj(vfeat).unsqueeze(1) # [1,1,d_model]
112
+
113
+ # 3) β€œreport:” token embeddings
114
+ enc = tokenizer("report:", return_tensors="pt").to(device)
115
+ txt_emb = t5.encoder.embed_tokens(enc.input_ids) # [1,L,d_model]
116
+
117
+ # 4) concat + mask
118
+ emb = torch.cat([prefix, txt_emb], dim=1) # [1,1+L,d]
119
+ mask = torch.cat([
120
+ torch.ones(1,1,device=device),
121
+ enc.attention_mask
122
+ ], dim=1) # [1,1+L]
123
+
124
+ # 5) encode + generate
125
+ enc_out = t5.encoder(inputs_embeds=emb, attention_mask=mask)
126
+ ids = t5.generate(
127
+ encoder_outputs = enc_out,
128
+ encoder_attention_mask = mask,
129
+ max_length = 64,
130
+ num_beams = 1,
131
+ do_sample = False,
132
+ eos_token_id = tokenizer.eos_token_id,
133
+ )
134
+ report = tokenizer.decode(ids[0], skip_special_tokens=True)
135
+
136
+ # ── DISPLAY ────────────────────────────────────────────────────────────────
137
+ c1, c2 = st.columns(2)
138
+ with c1:
139
+ st.subheader("Your Uploaded X-ray")
140
+ st.image(img, use_column_width=True)
141
+ st.markdown(
142
+ f"**File:** {st.session_state.uploaded.name} \n"
143
+ f"**Size:** {st.session_state.uploaded.size/1e6:.2f} MB"
144
+ )
145
+ with c2:
146
+ st.subheader("AI Diagnosis & Report")
147
+ st.markdown(
148
+ f"<div style='background:#e0f7fa;padding:12px;border-radius:6px;'>"
149
+ f"<strong>Primary Diagnosis</strong><br>{report}</div>",
150
+ unsafe_allow_html=True
151
+ )
152
+ if st.button("⬅️ Upload Another"):
153
+ st.session_state.stage = "upload"
154
+ del st.session_state.uploaded
155
+ st.experimental_rerun()
156
+
157
+ st.markdown("""
158
+ <hr style='margin:2em 0;'>
159
+ <p style='font-size:0.8em;color:gray;text-align:center;'>
160
+ Powered by your fine-tuned ViTβž”T5, both coming from a single pytorch_model.bin in Chest-Radiology.
161
+ </p>
162
+ """, unsafe_allow_html=True)