CanerDedeoglu committed (verified)
Commit 0b38f8d
1 parent: c38cd21

Rename handler.bak.py to mm_utils_local.py

Files changed (2)
  1. handler.bak.py +0 -540
  2. mm_utils_local.py +259 -0
handler.bak.py DELETED
@@ -1,540 +0,0 @@
"""
PULSE-7B Enhanced Handler
Ubden® Team - Edited by https://github.com/ck-cankurt
Support: Text, Image URLs, and Base64 encoded images
"""

import torch
from typing import Dict, List, Any
import base64
from io import BytesIO
from PIL import Image
import requests
import time

# Import utilities if available
try:
    from utils import (
        performance_monitor,
        validate_image_input,
        sanitize_parameters,
        get_system_info,
        create_health_check,
        deepseek_client
    )
    UTILS_AVAILABLE = True
except ImportError:
    UTILS_AVAILABLE = False
    deepseek_client = None
    print("⚠️ Utils module not found - performance monitoring and DeepSeek integration disabled")

# Try to import LLaVA modules for proper conversation handling
try:
    from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.mm_utils import tokenizer_image_token, process_images, KeywordsStoppingCriteria
    LLAVA_AVAILABLE = True
    print("✅ LLaVA modules imported successfully")
except ImportError:
    LLAVA_AVAILABLE = False
    print("⚠️ LLaVA modules not available - using basic text processing")

class EndpointHandler:
    def __init__(self, path=""):
        """
        Hey there! Let's get this PULSE-7B model up and running.
        We'll load it from the HuggingFace hub directly, so no worries about local files.

        Args:
            path: Model directory path (we actually ignore this and load from HF hub)
        """
        print("🚀 Starting up PULSE-7B handler...")
        print("📝 Enhanced by Ubden® Team - github.com/ck-cankurt")
        import sys
        print(f"🔧 Python version: {sys.version}")
        print(f"🔧 PyTorch version: {torch.__version__}")

        # Check transformers version
        try:
            import transformers
            print(f"🔧 Transformers version: {transformers.__version__}")

            # PULSE LLaVA works with transformers==4.37.2
            if transformers.__version__ == "4.37.2":
                print("✅ Using PULSE LLaVA compatible version (4.37.2)")
            elif "dev" in transformers.__version__ or "git" in str(transformers.__version__):
                print("⚠️ Using development version - may conflict with PULSE LLaVA")
            else:
                print("⚠️ Using different version - PULSE LLaVA prefers 4.37.2")
        except Exception as e:
            print(f"❌ Error checking transformers version: {e}")

        print(f"🔧 CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"🔧 CUDA device: {torch.cuda.get_device_name(0)}")

        # Let's see what hardware we're working with
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🖥️ Running on: {self.device}")

        try:
            # First attempt - PULSE demo's exact approach
            if LLAVA_AVAILABLE:
                print("📦 Using PULSE demo's load_pretrained_model approach...")
                from llava.model.builder import load_pretrained_model
                from llava.mm_utils import get_model_name_from_path

                model_path = "PULSE-ECG/PULSE-7B"
                model_name = get_model_name_from_path(model_path)

                self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
                    model_path=model_path,
                    model_base=None,
                    model_name=model_name,
                    load_8bit=False,
                    load_4bit=False
                )

                # Move model to device like demo
                self.model = self.model.to(self.device)
                self.use_pipeline = False
                print("✅ Model loaded successfully with PULSE demo's approach!")
                print(f"📸 Image processor: {type(self.image_processor).__name__}")

            else:
                raise ImportError("LLaVA modules not available")

        except Exception as e:
            print(f"⚠️ PULSE demo approach failed: {e}")
            print("🔄 Falling back to pipeline...")

            try:
                # Fallback - using pipeline
                from transformers import pipeline

                print("📦 Fetching model from HuggingFace Hub...")
                self.pipe = pipeline(
                    "text-generation",
                    model="PULSE-ECG/PULSE-7B",
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    device=0 if torch.cuda.is_available() else -1,
                    trust_remote_code=True,
                    model_kwargs={
                        "low_cpu_mem_usage": True,
                        "use_safetensors": True
                    }
                )
                self.use_pipeline = True
                self.image_processor = None
                print("✅ Model loaded successfully via pipeline!")

            except Exception as e2:
                print(f"😓 Pipeline also failed: {e2}")

                try:
                    # Last resort - manual loading
                    from transformers import AutoTokenizer, LlamaForCausalLM

                    print("📖 Manual loading as last resort...")
                    self.tokenizer = AutoTokenizer.from_pretrained(
                        "PULSE-ECG/PULSE-7B",
                        trust_remote_code=True
                    )

                    self.model = LlamaForCausalLM.from_pretrained(
                        "PULSE-ECG/PULSE-7B",
                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                        device_map="auto",
                        low_cpu_mem_usage=True,
                        trust_remote_code=True
                    )

                    if self.tokenizer.pad_token is None:
                        self.tokenizer.pad_token = self.tokenizer.eos_token
                        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

                    self.model.eval()
                    self.use_pipeline = False
                    self.image_processor = None
                    print("✅ Model loaded manually!")

                except Exception as e3:
                    print(f"😓 All approaches failed: {e3}")
                    self.pipe = None
                    self.model = None
                    self.tokenizer = None
                    self.image_processor = None
                    self.use_pipeline = None

        # Final status report
        print("\n🔍 Model Loading Status Report:")
        print(f" - use_pipeline: {self.use_pipeline}")
        print(f" - model: {'✅ Loaded' if hasattr(self, 'model') and self.model is not None else '❌ None'}")
        print(f" - tokenizer: {'✅ Loaded' if hasattr(self, 'tokenizer') and self.tokenizer is not None else '❌ None'}")
        print(f" - image_processor: {'✅ Loaded' if hasattr(self, 'image_processor') and self.image_processor is not None else '❌ None'}")
        print(f" - pipe: {'✅ Loaded' if hasattr(self, 'pipe') and self.pipe is not None else '❌ None'}")

        # Check if any model component loaded successfully
        has_model = hasattr(self, 'model') and self.model is not None
        has_tokenizer = hasattr(self, 'tokenizer') and self.tokenizer is not None
        has_pipe = hasattr(self, 'pipe') and self.pipe is not None
        has_image_processor = hasattr(self, 'image_processor') and self.image_processor is not None

        if not (has_model or has_tokenizer or has_pipe):
            print("💥 CRITICAL: No model components loaded successfully!")
        else:
            print("✅ At least one model component loaded successfully")
            if has_image_processor:
                print("🖼️ Vision capabilities available!")
            else:
                print("⚠️ No image processor - text-only mode")

    def is_valid_image_format(self, filename_or_url):
        """Validate image format like PULSE demo"""
        # Demo's supported formats
        image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]

        if filename_or_url.startswith(('http://', 'https://')):
            # For URLs, check the extension or content-type
            ext = filename_or_url.split('.')[-1].split('?')[0].lower()
            return ext in image_extensions
        else:
            # For base64 or local files
            return True  # Base64 will be validated during decode

    def process_image_input(self, image_input):
        """
        Handle both URL and base64 image inputs exactly like PULSE demo

        Args:
            image_input: Can be a URL string or base64 encoded image

        Returns:
            PIL Image object or None if something goes wrong
        """
        try:
            # Check if it's a URL (starts with http/https)
            if isinstance(image_input, str) and (image_input.startswith('http://') or image_input.startswith('https://')):
                print(f"🌐 Fetching image from URL: {image_input[:50]}...")

                # Validate format like demo
                if not self.is_valid_image_format(image_input):
                    print("❌ Invalid image format in URL")
                    return None

                # Demo's exact image loading approach
                response = requests.get(image_input, timeout=15)
                if response.status_code == 200:
                    image = Image.open(BytesIO(response.content)).convert("RGB")
                    print(f"✅ Image downloaded successfully! Size: {image.size}")
                    return image
                else:
                    print(f"❌ Failed to load image: status {response.status_code}")
                    return None

            # Must be base64 then
            elif isinstance(image_input, str):
                print("🔍 Decoding base64 image...")

                # Remove the data URL prefix if it exists
                base64_data = image_input
                if "base64," in image_input:
                    base64_data = image_input.split("base64,")[1]

                # Clean and validate base64 data
                base64_data = base64_data.strip().replace('\n', '').replace('\r', '').replace(' ', '')

                try:
                    image_data = base64.b64decode(base64_data)
                    image = Image.open(BytesIO(image_data)).convert('RGB')
                    print(f"✅ Base64 image decoded successfully! Size: {image.size}")
                    return image
                except Exception as decode_error:
                    print(f"❌ Base64 decode error: {decode_error}")
                    return None

        except Exception as e:
            print(f"❌ Couldn't process the image: {e}")
            return None

        return None

    def add_turkish_commentary(self, response: Dict[str, Any], enable_commentary: bool, timeout: int = 30) -> Dict[str, Any]:
        """Add Turkish commentary to the response using DeepSeek API"""
        if not enable_commentary:
            return response

        if not UTILS_AVAILABLE or not deepseek_client:
            print("⚠️ DeepSeek client not available - skipping Turkish commentary")
            response["commentary_status"] = "unavailable"
            return response

        if not deepseek_client.is_available():
            print("⚠️ DeepSeek API key not configured - skipping Turkish commentary")
            response["commentary_status"] = "api_key_missing"
            return response

        generated_text = response.get("generated_text", "")
        if not generated_text:
            print("⚠️ No generated text to comment on")
            response["commentary_status"] = "no_text"
            return response

        print("🔄 Adding Turkish commentary via DeepSeek...")
        commentary_result = deepseek_client.get_turkish_commentary(generated_text, timeout)

        if commentary_result["success"]:
            response["comment_text"] = commentary_result["comment_text"]
            response["commentary_model"] = commentary_result.get("model", "deepseek-chat")
            response["commentary_tokens"] = commentary_result.get("tokens_used", 0)
            response["commentary_status"] = "success"
            print("✅ Turkish commentary added successfully")
        else:
            response["comment_text"] = ""
            response["commentary_error"] = commentary_result["error"]
            response["commentary_status"] = "failed"
            print(f"❌ Failed to add Turkish commentary: {commentary_result['error']}")

        return response

    def health_check(self) -> Dict[str, Any]:
        """Health check endpoint"""
        if UTILS_AVAILABLE:
            return create_health_check()
        else:
            return {
                'status': 'healthy',
                'model': 'PULSE-7B',
                'timestamp': time.time(),
                'handler_version': '2.0.0'
            }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Main processing function - where the magic happens!

        Args:
            data: Input data with 'inputs' and optional 'parameters'

        Returns:
            List with the generated response
        """
        # Quick check - is our model ready?
        if self.use_pipeline is None:
            return [{
                "generated_text": "Oops! Model couldn't load properly. Please check the deployment settings.",
                "error": "Model initialization failed",
                "handler": "Ubden® Team Enhanced Handler"
            }]

        try:
            # Parse the inputs - flexible format support
            inputs = data.get("inputs", "")
            text = ""
            image = None

            if isinstance(inputs, dict):
                # Dictionary input - check for text and image
                # Support query field (new) plus original text/prompt fields
                text = inputs.get("query", inputs.get("text", inputs.get("prompt", str(inputs))))

                # Check for image in various formats
                image_input = inputs.get("image", inputs.get("image_url", inputs.get("image_base64", None)))
                if image_input:
                    image = self.process_image_input(image_input)
                    if image:
                        # Since we're in text-only mode, create smart ECG context
                        print(f"🖼️ Image loaded: {image.size[0]}x{image.size[1]} pixels - using text-only ECG analysis mode")

                        # Create ECG-specific prompt that mimics visual analysis
                        ecg_context = f"Analyzing an ECG image ({image.size[0]}x{image.size[1]} pixels). "

                        # Use demo's exact approach - no additional context, just the query
                        # Model is trained to understand ECG images from text queries
                        pass  # Keep text exactly as received
            else:
                # Simple string input
                text = str(inputs)

            if not text:
                return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]

            # Get generation parameters - using PULSE-7B demo's exact settings
            parameters = data.get("parameters", {})
            max_new_tokens = min(parameters.get("max_new_tokens", 1024), 8192)  # Demo uses 1024 default
            temperature = parameters.get("temperature", 0.05)  # Demo uses 0.05 for precise medical analysis
            top_p = parameters.get("top_p", 1.0)  # Demo uses 1.0 for full vocabulary access
            do_sample = parameters.get("do_sample", True)  # Demo uses sampling
            repetition_penalty = parameters.get("repetition_penalty", 1.0)  # Demo default

            print(f"🎛️ Generation params: max_tokens={max_new_tokens}, temp={temperature}, top_p={top_p}, do_sample={do_sample}, rep_penalty={repetition_penalty}")

            # Check if Turkish commentary is requested (NEW FEATURE)
            enable_turkish_commentary = parameters.get("enable_turkish_commentary", False)  # Default false

            # Using pipeline? Let's go!
            if self.use_pipeline:
                print(f"🎛️ Pipeline generation: temp={temperature}, tokens={max_new_tokens}")
                print(f"📝 Input text: '{text[:100]}...'")

                result = self.pipe(
                    text,
                    max_new_tokens=max_new_tokens,
                    min_new_tokens=200,  # Force very detailed analysis to match demo
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    repetition_penalty=repetition_penalty,
                    return_full_text=False  # Just the new stuff, not the input
                )

                # Pipeline returns a list, let's handle it
                if isinstance(result, list) and len(result) > 0:
                    generated_text = result[0].get("generated_text", "").strip()

                    print(f"🔍 Pipeline debug:")
                    print(f" - Raw result: '{str(result[0])[:200]}...'")
                    print(f" - Generated text length: {len(generated_text)}")

                    # Clean up common issues
                    if generated_text.startswith(text):
                        generated_text = generated_text[len(text):].strip()
                        print("🔧 Removed input text from output")

                    # Remove common artifacts
                    generated_text = generated_text.replace("</s>", "").strip()

                    if not generated_text:
                        print("❌ Pipeline generated empty text!")
                        generated_text = "Empty response from pipeline. Please try different parameters."

                    print(f"✅ Final pipeline text: '{generated_text[:100]}...' (length: {len(generated_text)})")

                    # Create response
                    response = {"generated_text": generated_text}

                    # Add Turkish commentary if requested (NEW FEATURE)
                    if enable_turkish_commentary:
                        response = self.add_turkish_commentary(response, True)

                    return [response]
                else:
                    generated_text = str(result).strip()

                    # Create response
                    response = {"generated_text": generated_text}

                    # Add Turkish commentary if requested (NEW FEATURE)
                    if enable_turkish_commentary:
                        response = self.add_turkish_commentary(response, True)

                    return [response]

            # Manual generation mode - using PULSE demo's exact approach
            else:
                print(f"🔥 Manual generation with PULSE demo logic: temp={temperature}, tokens={max_new_tokens}")
                print(f"📝 Input text: '{text[:100]}...'")

                # Text-only generation with enhanced ECG context
                print("🔤 Using enhanced text-only generation with ECG context")

                # Tokenize the enhanced prompt
                encoded = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=4096  # Increased for longer prompts
                )

                input_ids = encoded["input_ids"].to(self.device)
                attention_mask = encoded.get("attention_mask")
                if attention_mask is not None:
                    attention_mask = attention_mask.to(self.device)

                print(f"🔍 Enhanced generation debug:")
                print(f" - Enhanced prompt length: {len(text)} chars")
                print(f" - Input tokens: {input_ids.shape[-1]}")
                print(f" - Prompt preview: '{text[:150]}...'")

                # Generate with enhanced settings for medical analysis
                with torch.no_grad():
                    outputs = self.model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=max_new_tokens,
                        min_new_tokens=200,  # Force detailed response like demo
                        temperature=temperature,
                        top_p=top_p,
                        do_sample=do_sample,
                        repetition_penalty=repetition_penalty,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                        early_stopping=False
                    )

                # Decode and clean response
                generated_ids = outputs[0][input_ids.shape[-1]:]
                generated_text = self.tokenizer.decode(
                    generated_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True
                ).strip()

                # Aggressive cleanup of artifacts
                generated_text = generated_text.replace("</s>", "").strip()

                # Simple cleanup - just remove Answer prefix and parentheses
                if generated_text.startswith("(Answer:") and ")" in generated_text:
                    # Just remove the parentheses and Answer: prefix
                    end_paren = generated_text.find(")")
                    answer_content = generated_text[8:end_paren].strip()  # Remove "(Answer:"
                    # Keep the rest of the response if there is any
                    rest_of_response = generated_text[end_paren+1:].strip()

                    if rest_of_response:
                        generated_text = f"{answer_content}. {rest_of_response}"
                    else:
                        generated_text = answer_content

                elif generated_text.startswith("Answer:"):
                    generated_text = generated_text[7:].strip()

                # Remove only clear training artifacts
                cleanup_patterns = [
                    "In this task",
                    "I'm asking the respondent",
                    "The respondent should"
                ]

                for pattern in cleanup_patterns:
                    if pattern in generated_text:
                        parts = generated_text.split(pattern)
                        generated_text = parts[0].strip()
                        break

                # Only provide fallback if response is truly empty or malformed
                if len(generated_text) < 10 or generated_text.startswith("7)"):
                    print("⚠️ Malformed response detected, providing fallback...")
                    generated_text = "This ECG shows cardiac electrical activity. For accurate interpretation, please consult with a qualified cardiologist who can analyze the specific waveforms, intervals, and morphology patterns."

                print(f"✅ Enhanced text-only generation: '{generated_text[:100]}...' (length: {len(generated_text)})")

                # Create response
                response = {"generated_text": generated_text}

                # Add Turkish commentary if requested (NEW FEATURE)
                if enable_turkish_commentary:
                    response = self.add_turkish_commentary(response, True)

                return [response]

        except Exception as e:
            error_msg = f"Something went wrong during generation: {str(e)}"
            print(f"❌ {error_msg}")
            return [{
                "generated_text": "",
                "error": error_msg,
                "handler": "Ubden® Team Enhanced Handler"
            }]
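
For reference, here is a minimal sketch of a request the deleted handler accepted. The field names mirror the parsing code in __call__ above; the values and the URL are illustrative only:

# Hypothetical invocation of the deleted EndpointHandler.
# "query"/"text"/"prompt" and "image"/"image_url"/"image_base64" are the
# accepted input aliases per __call__ above; the URL is a placeholder.
handler = EndpointHandler()
payload = {
    "inputs": {
        "query": "What does this ECG show?",
        "image_url": "https://example.com/ecg.png",
    },
    "parameters": {
        "max_new_tokens": 1024,              # demo default, capped at 8192
        "temperature": 0.05,                 # demo default for precise analysis
        "enable_turkish_commentary": False,  # opt-in DeepSeek commentary
    },
}
result = handler(payload)  # -> [{"generated_text": "..."}]
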
mm_utils_local.py ADDED
@@ -0,0 +1,259 @@
# mm_utils_local.py
# LLaVA/PULSE-compatible, robust mm_utils (anyres + pad)
# - reads the crop_size/size fields safely
# - absorbs the preprocess vs. __call__ difference
# - pads so dimensions divide evenly by patch_size
# - compatible with the upstream signatures

from typing import Any, Dict, List, Optional, Sequence, Tuple
from io import BytesIO
import base64
import math
import ast

import torch
from PIL import Image
from transformers import StoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX  # for signature compatibility

# ---------- Helpers ----------

def _get_crop_size(processor: Any, default: int = 224) -> int:
    cs = getattr(processor, "crop_size", None)
    if cs is None:
        sz = getattr(processor, "size", None)
        if isinstance(sz, dict):
            return int(sz.get("shortest_edge", default))
        if isinstance(sz, int):
            return int(sz)
        return int(default)
    if isinstance(cs, dict):
        if "height" in cs:
            return int(cs["height"])
        if "shortest_edge" in cs:
            return int(cs["shortest_edge"])
        # unexpected dict: take the first value
        for v in cs.values():
            return int(v)
    return int(cs)

def _get_shortest_edge(processor: Any, fallback: Optional[int] = None) -> int:
    sz = getattr(processor, "size", None)
    if isinstance(sz, dict) and "shortest_edge" in sz:
        return int(sz["shortest_edge"])
    if isinstance(sz, int):
        return int(sz)
    return _get_crop_size(processor, default=(fallback or 224))

def _preprocess_one(processor: Any, img: Image.Image) -> torch.Tensor:
    # Some versions lack .preprocess; call the processor directly instead.
    if hasattr(processor, "preprocess"):
        out = processor.preprocess(img, return_tensors="pt")
    else:
        out = processor(img, return_tensors="pt")
    return out["pixel_values"][0]

def pad_to_multiple(image: Image.Image, multiple: int) -> Image.Image:
    w, h = image.size
    W = math.ceil(w / multiple) * multiple
    H = math.ceil(h / multiple) * multiple
    if (W, H) == (w, h):
        return image
    canvas = Image.new(image.mode, (W, H), (0, 0, 0))
    canvas.paste(image, (0, 0))
    return canvas

# ---------- Original API ----------

def select_best_resolution(original_size: Tuple[int, int], possible_resolutions: List[Tuple[int, int]]) -> Tuple[int, int]:
    """Same logic as upstream: pick the most effective, least wasteful resolution."""
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")
    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        down_w, down_h = int(original_width * scale), int(original_height * scale)
        effective = min(down_w * down_h, original_width * original_height)
        wasted = (width * height) - effective
        if (effective > max_effective_resolution) or (effective == max_effective_resolution and wasted < min_wasted_resolution):
            max_effective_resolution = effective
            min_wasted_resolution = wasted
            best_fit = (width, height)
    return best_fit

def resize_and_pad_image(image: Image.Image, target_resolution: Tuple[int, int]) -> Image.Image:
    """Resize to the target resolution while preserving aspect ratio, then pad with black."""
    ow, oh = image.size
    W, H = target_resolution
    sw, sh = W / ow, H / oh
    if sw < sh:
        nw, nh = W, min(math.ceil(oh * sw), H)
    else:
        nh, nw = H, min(math.ceil(ow * sh), W)
    resized = image.resize((nw, nh))
    canvas = Image.new("RGB", (W, H), (0, 0, 0))
    canvas.paste(resized, ((W - nw) // 2, (H - nh) // 2))
    return canvas

def divide_to_patches(image: Image.Image, patch_size: int) -> List[Image.Image]:
    """Split the image into patch_size x patch_size tiles."""
    patches: List[Image.Image] = []
    W, H = image.size
    for y in range(0, H, patch_size):
        for x in range(0, W, patch_size):
            patches.append(image.crop((x, y, x + patch_size, y + patch_size)))
    return patches

def get_anyres_image_grid_shape(image_size: Tuple[int, int], grid_pinpoints, patch_size: int) -> Tuple[int, int]:
    """Patch-grid shape after AnyRes: (W // patch, H // patch)."""
    if isinstance(grid_pinpoints, list):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size

def process_anyres_image(image: Image.Image, processor: Any, grid_pinpoints) -> torch.Tensor:
    """
    Robust AnyRes:
    - safe reading of crop_size/size
    - resize + pad to the target resolution
    - pad so dimensions divide evenly by patch_size
    - abstracts over the preprocess/__call__ difference
    """
    if isinstance(grid_pinpoints, list):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    patch_size = _get_crop_size(processor, default=224)
    shortest_edge = _get_shortest_edge(processor, fallback=patch_size)

    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)
    image_padded = pad_to_multiple(image_padded, patch_size)

    patches = divide_to_patches(image_padded, patch_size)
    image_original_resize = image.resize((shortest_edge, shortest_edge))

    image_patches = [_preprocess_one(processor, image_original_resize)]
    image_patches += [_preprocess_one(processor, p) for p in patches]
    return torch.stack(image_patches, dim=0)

def load_image_from_base64(image: str) -> Image.Image:
    return Image.open(BytesIO(base64.b64decode(image)))

def expand2square(pil_img: Image.Image, background_color: Tuple[int, int, int]) -> Image.Image:
    w, h = pil_img.size
    if w == h:
        return pil_img
    if w > h:
        result = Image.new(pil_img.mode, (w, w), background_color)
        result.paste(pil_img, (0, (w - h) // 2))
        return result
    result = Image.new(pil_img.mode, (h, h), background_color)
    result.paste(pil_img, ((h - w) // 2, 0))
    return result

def process_images(images: List[Image.Image], image_processor: Any, model_cfg: Any):
    """
    Same name and return type as the upstream API, but more robust:
    - pad: safe default (0.5, 0.5, 0.5) when image_mean is missing
    - anyres: robust process_anyres_image
    - else: fall back to per-image calls if the batched call raises TypeError
    """
    # some configs name the field mm_image_aspect_ratio
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) or getattr(model_cfg, "mm_image_aspect_ratio", None)
    new_images: List[torch.Tensor] = []

    if image_aspect_ratio == "pad":
        for image in images:
            img_mean = getattr(image_processor, "image_mean", [0.5, 0.5, 0.5])
            bg = tuple(int(x * 255) for x in img_mean)
            image_sq = expand2square(image, bg)
            image_t = _preprocess_one(image_processor, image_sq)
            new_images.append(image_t)

    elif image_aspect_ratio == "anyres":
        grid = getattr(model_cfg, "image_grid_pinpoints", "[(336,336)]")
        for image in images:
            image_t = process_anyres_image(image, image_processor, grid)
            new_images.append(image_t)

    else:
        try:
            out = image_processor(images, return_tensors="pt")
            return out["pixel_values"]
        except TypeError:
            outs = [image_processor(img, return_tensors="pt") for img in images]
            pix = [o["pixel_values"][0] for o in outs]
            return torch.stack(pix, dim=0)

    if all(x.shape == new_images[0].shape for x in new_images):
        return torch.stack(new_images, dim=0)
    return new_images

def tokenizer_image_token(prompt: str, tokenizer: Any, image_token_index: int = IMAGE_TOKEN_INDEX, return_tensors: Optional[str] = None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids: List[int] = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids

def get_model_name_from_path(model_path: str) -> str:
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith("checkpoint-"):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]

# Compatible with upstream: keyword-based stopping criterion
class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [kid.to(output_ids.device) for kid in self.keyword_ids]
        for kid in self.keyword_ids:
            truncated = output_ids[0, -kid.shape[0]:]
            if torch.equal(truncated, kid):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outs = []
        for i in range(output_ids.shape[0]):
            outs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outs)
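
A minimal usage sketch of the module above, assuming a LLaVA-style model already loaded elsewhere (for example via llava.model.builder.load_pretrained_model, as the old handler did). The names model, tokenizer, and image_processor are placeholders, the prompt format is illustrative, and the images= keyword follows LLaVA's generate() signature:

# Hypothetical wiring of mm_utils_local into a LLaVA-style generation call.
import torch
from PIL import Image
from mm_utils_local import process_images, tokenizer_image_token, KeywordsStoppingCriteria

image = Image.open("ecg.png").convert("RGB")  # placeholder input image
prompt = "USER: <image>\nWhat does this ECG show? ASSISTANT:"

# model.config.image_aspect_ratio selects the "pad"/"anyres"/default branch above
image_tensor = process_images([image], image_processor, model.config)

# tokenizer_image_token returns a 1-D tensor for return_tensors="pt"; add a batch dim
input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt").unsqueeze(0)
stopping = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)

with torch.no_grad():
    output_ids = model.generate(
        input_ids.to(model.device),
        images=image_tensor.to(model.device, dtype=torch.float16),
        max_new_tokens=256,
        stopping_criteria=[stopping],
    )
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])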