hackergeek committed on
Commit
8bbc742
·
verified ·
1 Parent(s): 8df6dbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -185
app.py CHANGED
@@ -1,206 +1,84 @@
1
- # =====================================================
2
- # Gradio Radiology Captioner (with VQA-ready model)
3
- # Loads epoch_04 checkpoint from Hugging Face
4
- # =====================================================
5
-
6
  import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
  from torchvision import transforms
10
  from PIL import Image
11
- import math
12
  import gradio as gr
13
- from huggingface_hub import hf_hub_download
14
- import numpy as np
15
- import pydicom
16
- import nibabel as nib
17
-
18
- # ======================
19
- # Device & dtype
20
- # ======================
21
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
- DTYPE = torch.bfloat16 if DEVICE=="cuda" else torch.float32
23
-
24
- # ======================
25
- # Tokenizer
26
- # ======================
27
  from transformers import AutoTokenizer
28
- tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
29
- tokenizer.pad_token = tokenizer.eos_token
30
- VOCAB_SIZE = tokenizer.vocab_size
31
- MAX_SEQ_LEN = 192
32
 
33
- # ======================
34
- # Conversation formatting
35
- # ======================
36
- IM_START = "<|im_start|>"
37
- IM_END = "<|im_end|>"
38
 
39
- def format_conversation(conversations):
40
- text = ""
41
- for msg in conversations:
42
- role = msg.get("from", msg.get("role", "assistant"))
43
- if role == "human":
44
- role = "user"
45
- content = msg.get("value", msg.get("content", ""))
46
- text += f"{IM_START}{role}\n{content.strip()}{IM_END}\n"
47
- return text.strip()
48
-
49
- # ======================
50
- # Model definition
51
- # ======================
52
- class ConvBlock(nn.Module):
53
- def __init__(self, dim_in, dim_out):
54
- super().__init__()
55
- self.dwconv = nn.Conv2d(dim_in, dim_in, 3, padding=1, groups=dim_in)
56
- self.norm = nn.LayerNorm(dim_in)
57
- self.pw1 = nn.Linear(dim_in, 4*dim_in)
58
- self.act = nn.GELU()
59
- self.pw2 = nn.Linear(4*dim_in, dim_out)
60
- self.shortcut = nn.Conv2d(dim_in, dim_out, 1) if dim_in!=dim_out else nn.Identity()
61
- def forward(self, x):
62
- res = self.shortcut(x)
63
- x = self.dwconv(x)
64
- x = x.permute(0,2,3,1)
65
- x = self.norm(x)
66
- x = x.permute(0,3,1,2)
67
- x = x.flatten(2).transpose(1,2)
68
- x = self.pw1(x)
69
- x = self.act(x)
70
- x = self.pw2(x)
71
- x = x.transpose(1,2).view(res.shape)
72
- return res + x
73
 
74
- class CNNEncoder(nn.Module):
75
- def __init__(self, dims=[96,192,384]):
76
- super().__init__()
77
- self.stem = nn.Sequential(
78
- nn.Conv2d(3,dims[0],4,4),
79
- nn.BatchNorm2d(dims[0]),
80
- nn.GELU()
81
- )
82
- self.stages = nn.ModuleList()
83
- for i in range(len(dims)-1):
84
- stage = nn.Sequential(*[ConvBlock(dims[i],dims[i]) for _ in range(3)],
85
- nn.Conv2d(dims[i], dims[i+1], 2,2))
86
- self.stages.append(stage)
87
- self.stages.append(nn.Sequential(*[ConvBlock(dims[-1],dims[-1]) for _ in range(3)]))
88
- self.norm = nn.LayerNorm(dims[-1])
89
- def forward(self, x):
90
- x = self.stem(x)
91
- for stage in self.stages:
92
- x = stage(x)
93
- x = x.flatten(2).transpose(1,2)
94
- x = self.norm(x)
95
- return x
96
 
97
- class LocalGraphProp(nn.Module):
98
- def __init__(self, dim, steps=3):
 
 
 
 
99
  super().__init__()
100
- self.steps = steps
101
- self.W_self = nn.Parameter(torch.tensor(1.0))
102
- self.W_neigh = nn.Parameter(torch.ones(8)/8)
103
- self.update = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
104
- def forward(self,x):
105
- B,L,D = x.shape
106
- H = W = int(math.sqrt(L))
107
- grid = x.view(B,H,W,D)
108
- for _ in range(self.steps):
109
- padded = F.pad(grid,(0,0,1,1,1,1), mode='replicate')
110
- neighbors=[]
111
- for dy in [-1,0,1]:
112
- for dx in [-1,0,1]:
113
- if dy==dx==0: continue
114
- neighbors.append(padded[:,1+dy:H+1+dy,1+dx:W+1+dx])
115
- neigh = torch.stack(neighbors, dim=3)
116
- agg = self.W_self*grid.unsqueeze(3) + self.W_neigh.view(1,1,1,8,1)*neigh
117
- agg = agg.sum(dim=3)
118
- upd = self.update(agg.view(-1,D)).view(B,H,W,D)
119
- grid = grid + upd
120
- grid = torch.tanh(grid)*0.5 + 0.5
121
- return grid.view(B,L,D)
122
 
123
- class RadiologyCaptioner(nn.Module):
124
- def __init__(self,d_model=384,nhead=6,num_layers=3):
125
- super().__init__()
126
- self.encoder = CNNEncoder(dims=[96,192,d_model])
127
- self.graph_prop = LocalGraphProp(d_model)
128
- decoder_layer = nn.TransformerDecoderLayer(
129
- d_model=d_model, nhead=nhead, dim_feedforward=4*d_model,
130
- dropout=0.1, activation='gelu', batch_first=True
131
- )
132
- self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
133
- self.embed = nn.Embedding(VOCAB_SIZE,d_model)
134
- self.pos_embed = nn.Parameter(torch.zeros(1,MAX_SEQ_LEN,d_model))
135
- self.head = nn.Linear(d_model,VOCAB_SIZE)
136
- self.d_model=d_model
137
- def encode_image(self,images):
138
- feats=self.encoder(images)
139
- feats=self.graph_prop(feats)
140
- return feats
141
- def forward(self,images,input_ids,labels=None):
142
- memory=self.encode_image(images)
143
- tgt=self.embed(input_ids)*math.sqrt(self.d_model)+self.pos_embed[:,:input_ids.shape[1]]
144
- out=self.decoder(tgt,memory)
145
- logits=self.head(out)
146
- if labels is not None:
147
- loss=F.cross_entropy(logits.reshape(-1,VOCAB_SIZE), labels.reshape(-1), ignore_index=-100)
148
- return logits, loss
149
- return logits
150
- @torch.no_grad()
151
- def generate(self, images, max_len=MAX_SEQ_LEN, temperature=0.8):
152
- B = images.shape[0]
153
- memory = self.encode_image(images)
154
- tokens = torch.full((B,1), tokenizer.bos_token_id or tokenizer.eos_token_id, dtype=torch.long, device=DEVICE)
155
- for _ in range(max_len):
156
- tgt = self.embed(tokens)*math.sqrt(self.d_model)+self.pos_embed[:,:tokens.shape[1]]
157
- logits = self.head(self.decoder(tgt,memory)[:,-1])
158
- next_token = (logits/temperature).softmax(-1).multinomial(1)
159
- tokens = torch.cat([tokens,next_token],dim=1)
160
- if next_token.item() == tokenizer.eos_token_id:
161
- break
162
- return tokens
163
 
164
- # ======================
165
- # Instantiate model & load checkpoint
166
- # ======================
167
- model = RadiologyCaptioner().to(DEVICE, dtype=DTYPE)
168
- checkpoint_path = hf_hub_download(
169
- repo_id="erfanasghariyan/RADIOCAP200",
170
- filename="model.pt",
171
- subfolder="checkpoints/epoch_04"
172
- )
173
- state_dict = torch.load(checkpoint_path, map_location=DEVICE)
174
- model.load_state_dict(state_dict)
175
 
176
- # ======================
177
- # Image transform
178
- # ======================
179
- IMG_SIZE = 224
180
  transform = transforms.Compose([
181
- transforms.Resize((IMG_SIZE,IMG_SIZE)),
182
  transforms.ToTensor(),
183
- transforms.Normalize(mean=[0.5]*3,std=[0.5]*3)
 
184
  ])
185
 
186
- def load_image(img):
187
- if isinstance(img,np.ndarray):
188
- img = Image.fromarray(img)
189
- elif isinstance(img,str) and img.lower().endswith(".dcm"):
190
- dcm = pydicom.dcmread(img)
191
- arr = dcm.pixel_array.astype(np.float32)
192
- arr = np.clip((arr-arr.min())/(arr.ptp()+1e-6),0,1)
193
- img = Image.fromarray((arr*255).astype(np.uint8)).convert("RGB")
194
  return transform(img).unsqueeze(0).to(DEVICE, dtype=DTYPE)
195
 
196
- # ======================
197
- # Gradio interface
198
- # ======================
199
- def predict(img):
200
  img_tensor = load_image(img)
201
- tokens = model.generate(img_tensor)
202
- caption = tokenizer.decode(tokens[0], skip_special_tokens=True)
203
- return caption
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
- iface = gr.Interface(fn=predict, inputs=gr.Image(type="filepath"), outputs="text", title="RADIOCAP200 Radiology Captioner")
206
- iface.launch()
 
 
 
 
 
 
1
  import torch
2
+ from torch import nn
 
3
  from torchvision import transforms
4
  from PIL import Image
 
5
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from transformers import AutoTokenizer
 
 
 
 
7
 
8
+ # ===========================
9
+ # تنظیمات دستگاه و dtype
10
+ # ===========================
11
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
13
 
14
+ # ===========================
15
+ # مسیر مدل و tokenizer
16
+ # ===========================
17
+ CHECKPOINT_PATH = "checkpoints/epoch_04/model.pt" # مسیر دانلود شده در Space
18
+ TOKENIZER_NAME = "bert-base-uncased" # یا مدل tokenizer مناسب شما
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # ===========================
23
+ # تعریف مدل (مثال ساده)
24
+ # ===========================
25
+ # توجه: مدل واقعی خودت را اینجا قرار بده
26
+ class DummyCaptionModel(nn.Module):
27
+ def __init__(self):
28
  super().__init__()
29
+ self.dummy = nn.Linear(10, 10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ def forward(self, x, question=None):
32
+ # خروجی فرضی
33
+ if question:
34
+ return "Answer to question: " + question
35
+ return "Generated caption for the image"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ model = DummyCaptionModel()
38
+ if torch.cuda.is_available():
39
+ model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
40
+ model.to(DEVICE)
41
+ model.eval()
 
 
 
 
 
 
42
 
43
+ # ===========================
44
+ # Transform تصویر
45
+ # ===========================
 
46
  transform = transforms.Compose([
47
+ transforms.Resize((224, 224)),
48
  transforms.ToTensor(),
49
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406],
50
+ # std=[0.229, 0.224, 0.225])
51
  ])
52
 
53
+ # ===========================
54
+ # تابع بارگذاری تصویر
55
+ # ===========================
56
+ def load_image(img: Image.Image):
57
+ """تبدیل PIL image به Tensor"""
 
 
 
58
  return transform(img).unsqueeze(0).to(DEVICE, dtype=DTYPE)
59
 
60
+ # ===========================
61
+ # تابع اصلی پیش‌بینی
62
+ # ===========================
63
+ def predict(img: Image.Image, question: str = ""):
64
  img_tensor = load_image(img)
65
+ # اگر سوال خالی بود کپشن تولید کن، وگرنه VQA
66
+ output_text = model(img_tensor, question.strip() or None)
67
+ return output_text
68
+
69
+ # ===========================
70
+ # Interface گریدیو
71
+ # ===========================
72
+ iface = gr.Interface(
73
+ fn=predict,
74
+ inputs=[
75
+ gr.Image(type="pil", label="Upload Radiology Image"),
76
+ gr.Textbox(label="Optional Question (for VQA)", placeholder="Ask a question or leave empty for caption")
77
+ ],
78
+ outputs=gr.Textbox(label="Output"),
79
+ title="RADIOCAP200: Radiology Caption + VQA",
80
+ description="Upload a radiology image and optionally ask a question. If the question is empty, model generates a caption. Otherwise, it answers the question."
81
+ )
82
 
83
+ if __name__ == "__main__":
84
+ iface.launch(server_name="0.0.0.0", server_port=7860, share=True)