RakeshNJ12345 commited on
Commit
2ca597d
Β·
verified Β·
1 Parent(s): 393cbcb

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +18 -6
src/streamlit_app.py CHANGED
@@ -24,19 +24,31 @@ from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
24
  # ─── point at your 1.2 GB model repo, NOT this Space —──────────────────────────
25
  HF_MODEL_ID = "RakeshNJ12345/Chest-Radiology"
26
 
 
27
  @st.cache_resource(show_spinner=False)
28
  def load_models():
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
- # 1) Vision trunk (fine‐tuned ViT weights under `vit/` in your model repo)
32
- vit = ViTModel.from_pretrained(f"{HF_MODEL_ID}/vit").to(device)
33
-
34
- # 2) T5 + tokenizer (your fine‐tuned report generator)
35
- t5 = T5ForConditionalGeneration.from_pretrained(HF_MODEL_ID).to(device)
36
- tok = T5Tokenizer.from_pretrained(HF_MODEL_ID)
 
 
 
 
 
 
 
 
 
 
37
 
38
  return device, vit, t5, tok
39
 
 
40
  device, vit, t5, tokenizer = load_models()
41
 
42
  # ─── preprocessing for ViT —────────────────────────────────────────────────────
 
24
  # ─── point at your 1.2 GB model repo, NOT this Space —──────────────────────────
25
  HF_MODEL_ID = "RakeshNJ12345/Chest-Radiology"
26
 
27
+ @st.cache_resource(show_spinner=False)
28
  @st.cache_resource(show_spinner=False)
29
  def load_models():
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
 
32
+ # 1) Frozen ViT pulled from the "vit/" folder of your model repo
33
+ vit = ViTModel.from_pretrained(
34
+ HF_MODEL_ID,
35
+ subfolder="vit",
36
+ local_files_only=True,
37
+ ).to(device)
38
+
39
+ # 2) Fine-tuned T5 & tokenizer at the root of that same repo
40
+ t5 = T5ForConditionalGeneration.from_pretrained(
41
+ HF_MODEL_ID,
42
+ local_files_only=True,
43
+ ).to(device)
44
+ tok = T5Tokenizer.from_pretrained(
45
+ HF_MODEL_ID,
46
+ local_files_only=True,
47
+ )
48
 
49
  return device, vit, t5, tok
50
 
51
+
52
  device, vit, t5, tokenizer = load_models()
53
 
54
  # ─── preprocessing for ViT —────────────────────────────────────────────────────