RakeshNJ12345 committed on
Commit
0bdd3ac
·
verified ·
1 Parent(s): e2aeef9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -32
app.py CHANGED
@@ -6,15 +6,15 @@ import torch.nn as nn
6
  import numpy as np
7
  from PIL import Image
8
  from transformers import (
9
- ViTConfig, ViTModel,
10
- BartConfig, BartForConditionalGeneration, AutoTokenizer
11
  )
12
  from huggingface_hub import hf_hub_download
13
- import safetensors.torch as st
14
  import gradio as gr
15
 
16
- # ───────────────────────── Helpers ─────────────────────────
17
  def build_bad_words_ids(tok: AutoTokenizer):
 
18
  bad_phrases = [
19
  "XXXX", "xxxx", "X-XXXX", "x-XXXX", "x - XXXX", "x -xxxx", "x-xxxx",
20
  "X - XXXX", "xβ€”XXXX", "x–XXXX", "x β€” XXXX", "x – XXXX",
@@ -25,9 +25,10 @@ def build_bad_words_ids(tok: AutoTokenizer):
25
  ids = tok(phrase, add_special_tokens=False).input_ids
26
  if ids and not all(i == tok.unk_token_id for i in ids):
27
  bad_ids.append(ids)
28
- return bad_ids or None
29
 
30
  def normalize_report(text: str) -> str:
 
31
  if not text:
32
  return text
33
  text = re.sub(r'\bx\s*[-–—]?\s*xxxx\b', 'x-ray', text, flags=re.IGNORECASE)
@@ -40,32 +41,20 @@ def normalize_report(text: str) -> str:
40
  # ─── 1) MODEL LOADING ─────────────────────────────────────────────────────────
41
  def load_model():
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
- repo_id = "RakeshNJ12345/MMic-CXR"
 
44
 
45
- base_path = "mimic_trained/final" # ✅ correct subfolder root
 
 
 
46
 
47
- # 1a) load ViT encoder manually
48
- vit_cfg = ViTConfig.from_pretrained(repo_id, subfolder=f"{base_path}/vit")
49
- vit = ViTModel(vit_cfg)
50
- vit_path = hf_hub_download(repo_id, filename="model.safetensors", subfolder=f"{base_path}/vit")
51
- vit_state = st.load_file(vit_path)
52
- vit.load_state_dict(vit_state)
53
- vit = vit.to(device)
54
-
55
- # 1b) load decoder (BioBART) manually
56
- dec_cfg = BartConfig.from_pretrained(repo_id, subfolder=f"{base_path}/decoder")
57
- dec = BartForConditionalGeneration(dec_cfg)
58
- dec_path = hf_hub_download(repo_id, filename="model.safetensors", subfolder=f"{base_path}/decoder")
59
- dec_state = st.load_file(dec_path)
60
- dec.load_state_dict(dec_state)
61
- dec = dec.to(device)
62
-
63
-
64
- # 1c) tokenizer
65
  tok = AutoTokenizer.from_pretrained(repo_id, subfolder=f"{base_path}/decoder")
66
  tok.clean_up_tokenization_spaces = True
67
 
68
- # 1d) load projection head
69
  proj_path = hf_hub_download(repo_id=repo_id, filename="proj.bin", subfolder=base_path)
70
  loaded = torch.load(proj_path, map_location=device)
71
  if isinstance(loaded, dict):
@@ -73,13 +62,13 @@ def load_model():
73
  proj = nn.Linear(vit.config.hidden_size, dec.config.d_model)
74
  proj.load_state_dict(sd)
75
  else:
76
- proj = loaded
77
  proj = proj.to(device)
78
 
79
- # 1e) blocklist
80
  bad_words_ids = build_bad_words_ids(tok)
81
 
82
- # 1f) wrapper model
83
  class TwoViewModel(nn.Module):
84
  def __init__(self, vit, dec, proj, tok, bad_words_ids=None):
85
  super().__init__()
@@ -94,13 +83,16 @@ def load_model():
94
 
95
  def generate(self, img_f, img_l, finds, max_len=128, num_beams=4):
96
  device = img_f.device
97
- # extract [CLS]
 
98
  out_f = self.vit(pixel_values=img_f).last_hidden_state[:, 0]
99
  out_l = self.vit(pixel_values=img_l).last_hidden_state[:, 0]
 
 
100
  avg = 0.5 * (out_f + out_l)
101
  prefix = self.proj(avg).unsqueeze(1) # [B,1,D]
102
 
103
- # prepend findings if available
104
  if (finds or "").strip():
105
  enc = self.tok(finds, return_tensors="pt", padding=True, truncation=True).to(device)
106
  text_emb = self.dec.get_encoder().embed_tokens(enc.input_ids)
@@ -133,6 +125,7 @@ model, tokenizer, device = load_model()
133
 
134
  # ─── 2) PREPROCESS ─────────────────────────────────────────────────────────────
135
  def preprocess(img: Image.Image) -> torch.Tensor:
 
136
  img = img.convert("RGB").resize((224, 224))
137
  arr = np.array(img).astype(np.float32) / 255.0
138
  if arr.ndim == 2:
@@ -155,7 +148,7 @@ def generate_report(frontal, lateral, findings, beams, max_len):
155
  f_t, l_t, findings or "", max_len=max_len, num_beams=beams
156
  )
157
  text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
158
- text = normalize_report(text)
159
  return text or "<empty>"
160
  except Exception as e:
161
  traceback.print_exc()
 
6
  import numpy as np
7
  from PIL import Image
8
  from transformers import (
9
+ AutoConfig, AutoModel, # ← handles dinov2/ViT automatically
10
+ AutoModelForSeq2SeqLM, AutoTokenizer
11
  )
12
  from huggingface_hub import hf_hub_download
 
13
  import gradio as gr
14
 
15
+ # ───────────────────────── Helpers: blocklist + text normalizer ─────────────────────────
16
  def build_bad_words_ids(tok: AutoTokenizer):
17
+ """Build token id sequences to block anonymization artifacts."""
18
  bad_phrases = [
19
  "XXXX", "xxxx", "X-XXXX", "x-XXXX", "x - XXXX", "x -xxxx", "x-xxxx",
20
  "X - XXXX", "xβ€”XXXX", "x–XXXX", "x β€” XXXX", "x – XXXX",
 
25
  ids = tok(phrase, add_special_tokens=False).input_ids
26
  if ids and not all(i == tok.unk_token_id for i in ids):
27
  bad_ids.append(ids)
28
+ return bad_ids or None # HF expects None if empty
29
 
30
  def normalize_report(text: str) -> str:
31
+ """Cleanup on generated text to replace/remove anonymization placeholders."""
32
  if not text:
33
  return text
34
  text = re.sub(r'\bx\s*[-–—]?\s*xxxx\b', 'x-ray', text, flags=re.IGNORECASE)
 
41
  # ─── 1) MODEL LOADING ─────────────────────────────────────────────────────────
42
  def load_model():
43
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
+ repo_id = "RakeshNJ12345/MMic-CXR"
45
+ base_path = "mimic_trained/final" # ← your confirmed path
46
 
47
+ # 1a) Encoder (DINOv2/ViT via AutoModel so config decides the class)
48
+ enc_cfg = AutoConfig.from_pretrained(repo_id, subfolder=f"{base_path}/vit")
49
+ vit = AutoModel.from_pretrained(repo_id, subfolder=f"{base_path}/vit").to(device)
50
+ # enc_cfg.model_type will be 'dinov2' in your case; vit.config.hidden_size is available.
51
 
52
+ # 1b) Decoder & tokenizer (BioBART)
53
+ dec = AutoModelForSeq2SeqLM.from_pretrained(repo_id, subfolder=f"{base_path}/decoder").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  tok = AutoTokenizer.from_pretrained(repo_id, subfolder=f"{base_path}/decoder")
55
  tok.clean_up_tokenization_spaces = True
56
 
57
+ # 1c) Projection head
58
  proj_path = hf_hub_download(repo_id=repo_id, filename="proj.bin", subfolder=base_path)
59
  loaded = torch.load(proj_path, map_location=device)
60
  if isinstance(loaded, dict):
 
62
  proj = nn.Linear(vit.config.hidden_size, dec.config.d_model)
63
  proj.load_state_dict(sd)
64
  else:
65
+ proj = loaded # if you saved an nn.Linear directly
66
  proj = proj.to(device)
67
 
68
+ # 1d) Blocklist for anonymization artifacts
69
  bad_words_ids = build_bad_words_ids(tok)
70
 
71
+ # 1e) Wrapper
72
  class TwoViewModel(nn.Module):
73
  def __init__(self, vit, dec, proj, tok, bad_words_ids=None):
74
  super().__init__()
 
83
 
84
  def generate(self, img_f, img_l, finds, max_len=128, num_beams=4):
85
  device = img_f.device
86
+
87
+ # CLS embeddings from both views
88
  out_f = self.vit(pixel_values=img_f).last_hidden_state[:, 0]
89
  out_l = self.vit(pixel_values=img_l).last_hidden_state[:, 0]
90
+
91
+ # average + project → prefix embedding
92
  avg = 0.5 * (out_f + out_l)
93
  prefix = self.proj(avg).unsqueeze(1) # [B,1,D]
94
 
95
+ # prepend findings text (optional)
96
  if (finds or "").strip():
97
  enc = self.tok(finds, return_tensors="pt", padding=True, truncation=True).to(device)
98
  text_emb = self.dec.get_encoder().embed_tokens(enc.input_ids)
 
125
 
126
  # ─── 2) PREPROCESS ─────────────────────────────────────────────────────────────
127
  def preprocess(img: Image.Image) -> torch.Tensor:
128
+ # Basic resize + [0,1] scaling; works across ViT/DINOv2
129
  img = img.convert("RGB").resize((224, 224))
130
  arr = np.array(img).astype(np.float32) / 255.0
131
  if arr.ndim == 2:
 
148
  f_t, l_t, findings or "", max_len=max_len, num_beams=beams
149
  )
150
  text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
151
+ text = normalize_report(text) # cleanup anonymization artifacts
152
  return text or "<empty>"
153
  except Exception as e:
154
  traceback.print_exc()