CanerDedeoglu committed on
Commit cec111a · verified · 1 Parent(s): 18475c7

Delete mm_utils_local.py

Files changed (1)
  1. mm_utils_local.py +0 -259
mm_utils_local.py DELETED
@@ -1,259 +0,0 @@
- # mm_utils_local.py
- # Robust mm_utils compatible with LLaVA/PULSE (anyres + pad)
- # - reads the crop_size/size fields safely
- # - absorbs the preprocess vs. __call__ difference
- # - pads so dimensions divide evenly by patch_size
- # - matches the upstream signatures
-
- from typing import Any, Dict, List, Optional, Sequence, Tuple
- from io import BytesIO
- import base64
- import math
- import ast
-
- import torch
- from PIL import Image
- from transformers import StoppingCriteria
- from llava.constants import IMAGE_TOKEN_INDEX  # for signature compatibility
-
- # ---------- Helpers ----------
-
- def _get_crop_size(processor: Any, default: int = 224) -> int:
-     cs = getattr(processor, "crop_size", None)
-     if cs is None:
-         sz = getattr(processor, "size", None)
-         if isinstance(sz, dict):
-             return int(sz.get("shortest_edge", default))
-         if isinstance(sz, int):
-             return int(sz)
-         return int(default)
-     if isinstance(cs, dict):
-         if "height" in cs:
-             return int(cs["height"])
-         if "shortest_edge" in cs:
-             return int(cs["shortest_edge"])
-         # unexpected dict: take the first value
-         for v in cs.values():
-             return int(v)
-     return int(cs)
-
- def _get_shortest_edge(processor: Any, fallback: Optional[int] = None) -> int:
-     sz = getattr(processor, "size", None)
-     if isinstance(sz, dict) and "shortest_edge" in sz:
-         return int(sz["shortest_edge"])
-     if isinstance(sz, int):
-         return int(sz)
-     return _get_crop_size(processor, default=(fallback or 224))
-
- def _preprocess_one(processor: Any, img: Image.Image) -> torch.Tensor:
-     # Some versions lack .preprocess; fall back to calling the processor directly.
-     if hasattr(processor, "preprocess"):
-         out = processor.preprocess(img, return_tensors="pt")
-     else:
-         out = processor(img, return_tensors="pt")
-     return out["pixel_values"][0]
-
- def pad_to_multiple(image: Image.Image, multiple: int) -> Image.Image:
-     w, h = image.size
-     W = math.ceil(w / multiple) * multiple
-     H = math.ceil(h / multiple) * multiple
-     if (W, H) == (w, h):
-         return image
-     canvas = Image.new(image.mode, (W, H), (0, 0, 0))
-     canvas.paste(image, (0, 0))
-     return canvas
-
- # ---------- Original API ----------
-
- def select_best_resolution(original_size: Tuple[int, int], possible_resolutions: List[Tuple[int, int]]) -> Tuple[int, int]:
-     """Same logic as upstream: pick the resolution with the largest effective area and the least waste."""
-     original_width, original_height = original_size
-     best_fit = None
-     max_effective_resolution = 0
-     min_wasted_resolution = float("inf")
-     for width, height in possible_resolutions:
-         scale = min(width / original_width, height / original_height)
-         down_w, down_h = int(original_width * scale), int(original_height * scale)
-         effective = min(down_w * down_h, original_width * original_height)
-         wasted = (width * height) - effective
-         if (effective > max_effective_resolution) or (effective == max_effective_resolution and wasted < min_wasted_resolution):
-             max_effective_resolution = effective
-             min_wasted_resolution = wasted
-             best_fit = (width, height)
-     return best_fit
-
- def resize_and_pad_image(image: Image.Image, target_resolution: Tuple[int, int]) -> Image.Image:
-     """Resize to the target resolution while preserving aspect ratio, then pad with black."""
-     ow, oh = image.size
-     W, H = target_resolution
-     sw, sh = W / ow, H / oh
-     if sw < sh:
-         nw, nh = W, min(math.ceil(oh * sw), H)
-     else:
-         nh, nw = H, min(math.ceil(ow * sh), W)
-     resized = image.resize((nw, nh))
-     canvas = Image.new("RGB", (W, H), (0, 0, 0))
-     canvas.paste(resized, ((W - nw) // 2, (H - nh) // 2))
-     return canvas
-
- def divide_to_patches(image: Image.Image, patch_size: int) -> List[Image.Image]:
-     """Split the image into patch_size x patch_size tiles."""
-     patches: List[Image.Image] = []
-     W, H = image.size
-     for y in range(0, H, patch_size):
-         for x in range(0, W, patch_size):
-             patches.append(image.crop((x, y, x + patch_size, y + patch_size)))
-     return patches
-
- def get_anyres_image_grid_shape(image_size: Tuple[int, int], grid_pinpoints, patch_size: int) -> Tuple[int, int]:
-     """Patch grid shape after AnyRes: (W // patch_size, H // patch_size)."""
-     if isinstance(grid_pinpoints, list):
-         possible_resolutions = grid_pinpoints
-     else:
-         possible_resolutions = ast.literal_eval(grid_pinpoints)
-     width, height = select_best_resolution(image_size, possible_resolutions)
-     return width // patch_size, height // patch_size
-
- def process_anyres_image(image: Image.Image, processor: Any, grid_pinpoints) -> torch.Tensor:
-     """
-     Robust AnyRes:
-     - safe crop_size/size reads
-     - resize + pad to the target resolution
-     - extra padding so dimensions divide evenly by patch_size
-     - abstracts over the preprocess/__call__ difference
-     """
-     if isinstance(grid_pinpoints, list):
-         possible_resolutions = grid_pinpoints
-     else:
-         possible_resolutions = ast.literal_eval(grid_pinpoints)
-
-     patch_size = _get_crop_size(processor, default=224)
-     shortest_edge = _get_shortest_edge(processor, fallback=patch_size)
-
-     best_resolution = select_best_resolution(image.size, possible_resolutions)
-     image_padded = resize_and_pad_image(image, best_resolution)
-     image_padded = pad_to_multiple(image_padded, patch_size)
-
-     patches = divide_to_patches(image_padded, patch_size)
-     image_original_resize = image.resize((shortest_edge, shortest_edge))
-
-     image_patches = [_preprocess_one(processor, image_original_resize)]
-     image_patches += [_preprocess_one(processor, p) for p in patches]
-     return torch.stack(image_patches, dim=0)
-
- def load_image_from_base64(image: str) -> Image.Image:
-     return Image.open(BytesIO(base64.b64decode(image)))
-
- def expand2square(pil_img: Image.Image, background_color: Tuple[int, int, int]) -> Image.Image:
-     w, h = pil_img.size
-     if w == h:
-         return pil_img
-     if w > h:
-         result = Image.new(pil_img.mode, (w, w), background_color)
-         result.paste(pil_img, (0, (w - h) // 2))
-         return result
-     result = Image.new(pil_img.mode, (h, h), background_color)
-     result.paste(pil_img, ((h - w) // 2, 0))
-     return result
-
- def process_images(images: List[Image.Image], image_processor: Any, model_cfg: Any):
-     """
-     Same name and return type as the upstream API, but more robust:
-     - pad: safe default mean (0.5, 0.5, 0.5) when image_mean is missing
-     - anyres: robust process_anyres_image
-     - else: fall back to per-image calls if the batched call raises TypeError
-     """
-     # some configs name this field mm_image_aspect_ratio
-     image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) or getattr(model_cfg, "mm_image_aspect_ratio", None)
-     new_images: List[torch.Tensor] = []
-
-     if image_aspect_ratio == "pad":
-         for image in images:
-             img_mean = getattr(image_processor, "image_mean", [0.5, 0.5, 0.5])
-             bg = tuple(int(x * 255) for x in img_mean)
-             image_sq = expand2square(image, bg)
-             image_t = _preprocess_one(image_processor, image_sq)
-             new_images.append(image_t)
-
-     elif image_aspect_ratio == "anyres":
-         grid = getattr(model_cfg, "image_grid_pinpoints", "[(336,336)]")
-         for image in images:
-             image_t = process_anyres_image(image, image_processor, grid)
-             new_images.append(image_t)
-
-     else:
-         try:
-             out = image_processor(images, return_tensors="pt")
-             return out["pixel_values"]
-         except TypeError:
-             outs = [image_processor(img, return_tensors="pt") for img in images]
-             pix = [o["pixel_values"][0] for o in outs]
-             return torch.stack(pix, dim=0)
-
-     if all(x.shape == new_images[0].shape for x in new_images):
-         return torch.stack(new_images, dim=0)
-     return new_images
-
- def tokenizer_image_token(prompt: str, tokenizer: Any, image_token_index: int = IMAGE_TOKEN_INDEX, return_tensors: Optional[str] = None):
-     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
-
-     def insert_separator(X, sep):
-         return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
-
-     input_ids: List[int] = []
-     offset = 0
-     if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
-         offset = 1
-         input_ids.append(prompt_chunks[0][0])
-
-     for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
-         input_ids.extend(x[offset:])
-
-     if return_tensors is not None:
-         if return_tensors == "pt":
-             return torch.tensor(input_ids, dtype=torch.long)
-         raise ValueError(f"Unsupported tensor type: {return_tensors}")
-     return input_ids
-
- def get_model_name_from_path(model_path: str) -> str:
-     model_path = model_path.strip("/")
-     model_paths = model_path.split("/")
-     if model_paths[-1].startswith("checkpoint-"):
-         return model_paths[-2] + "_" + model_paths[-1]
-     else:
-         return model_paths[-1]
-
- # Compatible with upstream: keyword-based stopping criterion
- class KeywordsStoppingCriteria(StoppingCriteria):
-     def __init__(self, keywords, tokenizer, input_ids):
-         self.keywords = keywords
-         self.keyword_ids = []
-         self.max_keyword_len = 0
-         for keyword in keywords:
-             cur_keyword_ids = tokenizer(keyword).input_ids
-             if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
-                 cur_keyword_ids = cur_keyword_ids[1:]
-             if len(cur_keyword_ids) > self.max_keyword_len:
-                 self.max_keyword_len = len(cur_keyword_ids)
-             self.keyword_ids.append(torch.tensor(cur_keyword_ids))
-         self.tokenizer = tokenizer
-         self.start_len = input_ids.shape[1]
-
-     def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-         offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
-         self.keyword_ids = [kid.to(output_ids.device) for kid in self.keyword_ids]
-         for kid in self.keyword_ids:
-             truncated = output_ids[0, -kid.shape[0]:]
-             if torch.equal(truncated, kid):
-                 return True
-         outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
-         for keyword in self.keywords:
-             if keyword in outputs:
-                 return True
-         return False
-
-     def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-         outs = []
-         for i in range(output_ids.shape[0]):
-             outs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
-         return all(outs)
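
For reference, a minimal sketch of how the deleted helpers were typically wired together at inference time. It is not part of this commit: the checkpoint names, the prompt, the image path, and the stop keyword are illustrative assumptions, and the import of mm_utils_local naturally only works on revisions before this deletion.

# Hypothetical usage sketch -- names below are assumptions, not from this repo.
from types import SimpleNamespace

from PIL import Image
from transformers import AutoTokenizer, CLIPImageProcessor

from mm_utils_local import (
    KeywordsStoppingCriteria,
    process_images,
    tokenizer_image_token,
)

# Stand-ins; any LLaVA-style checkpoint with a CLIP vision tower looks similar.
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")

# process_images only reads image_aspect_ratio (or mm_image_aspect_ratio) and
# image_grid_pinpoints, so a plain namespace stands in for the model config.
model_cfg = SimpleNamespace(
    image_aspect_ratio="anyres",
    image_grid_pinpoints="[(336, 672), (672, 336), (672, 672)]",
)

image = Image.open("example.jpg").convert("RGB")
# anyres: one (1 + num_patches, 3, 336, 336) tensor per input image
pixel_values = process_images([image], processor, model_cfg)

# "<image>" is replaced by IMAGE_TOKEN_INDEX between the tokenized text chunks
prompt = "USER: <image>\nDescribe the image. ASSISTANT:"
input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt").unsqueeze(0)

# stop generation once the EOS keyword appears in the decoded tail
stopping = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)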