import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from transformers import TorchAoConfig, Qwen2_5_VLForConditionalGeneration, Gemma3ForConditionalGeneration, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq, AutoModel
from qwen_vl_utils import process_vision_info
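# process_vision_info is provided by the `qwen-vl-utils` pip package that
# accompanies the Qwen2.5-VL models.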
import gc
# from transformers.image_utils import load_image
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
class VLMManager:
"""
A manager class for Vision-Language Models that handles model loading,
caching, and dynamic switching between different models.
"""
def __init__(self, default_model: str = "Gemma3-4B"):
"""
Initialize the VLM Manager with a default model.
Args:
default_model (str): The default model to load initially.
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.current_model_name = None
self.processor = None
self.tokenizer = None # Initialize tokenizer attribute
self.model = None
self.system_message = """
        You are an expert, culturally aware image-analysis assistant. For every image:
1. Output exactly 40 words in total.
2. Use a single paragraph (no lists or bullet points).
3. Describe Who (appearance/emotion), What (action), and Where (setting).
4. Do NOT include opinions or speculations.
5. If you go over 40 words, shorten or remove non-essential details.
"""
self.user_prompt = """
Given this image, please provide an image description of around 40 words with extensive and detailed visual information.
Descriptions must be objective: focus on how you would describe the image to someone who can't see it, without your own opinions/speculations.
The text needs to include the main concept and describe the content of the image in detail by including:
- Who?: The visual appearance and observable emotions (e.g., "is smiling") of persons and animals.
- What?: The actions performed in the image.
- Where?: The setting of the image, including the size, color, and relationships between objects.
"""
# Load the default model
self.load_model(default_model)
def load_model(self, model_name: str):
"""
Load a VLM model. If the model is already loaded, return the cached version.
Args:
model_name (str): The name of the model to load.
"""
# If the requested model is already loaded, no need to reload
if self.current_model_name == model_name and self.model is not None:
print(f"Model {model_name} is already loaded, using cached version.")
if self.current_model_name == "InternVL3_5-8B":
return self.tokenizer, self.model
else:
return self.processor, self.model
print(f"Loading model: {model_name}")
# Clear current model from memory if exists
if self.model is not None:
del self.model
self.model = None
if self.current_model_name == "InternVL3_5-8B":
if hasattr(self, 'tokenizer') and self.tokenizer is not None:
del self.tokenizer
self.tokenizer = None
else:
if hasattr(self, 'processor') and self.processor is not None:
del self.processor
self.processor = None
# Force garbage collection and clear CUDA cache
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize() # Wait for all operations to complete
# Load the new model
if model_name == "SmolVLM-500M":
self.processor, self.model = self._load_smolvlm_model("HuggingFaceTB/SmolVLM-500M-Instruct")
elif model_name == "Qwen2.5-VL-7B":
self.processor, self.model = self._load_qwen25_model("Qwen/Qwen2.5-VL-7B-Instruct")
elif model_name == "InternVL3_5-8B":
self.tokenizer, self.model = self._load_internvl35_model("OpenGVLab/InternVL3_5-8B-Instruct")
elif model_name == "Gemma3-4B":
self.processor, self.model = self._load_gemma3_model("google/gemma-3-4b-it")
else:
raise ValueError(f"Model {model_name} is not supported or not available.")
        self.current_model_name = model_name
        print(f"Successfully loaded model: {model_name}")
        # Return the same pair the cached path returns, so callers get a
        # consistent interface whether or not a reload happened.
        if model_name == "InternVL3_5-8B":
            return self.tokenizer, self.model
        return self.processor, self.model
def generate_caption(self, image):
"""
Generate a caption for the given image using the loaded model.
Args:
processor: The processor for the model.
model: The model to use for generating the caption.
image: The image to generate a caption for.
"""
if self.current_model_name == "SmolVLM-500M":
return self._inference_smolvlm_model(image)
elif self.current_model_name == "Qwen2.5-VL-7B":
return self._inference_qwen25_model(image)
elif self.current_model_name == "InternVL3_5-8B":
return self._inference_internvl35_model(image)
elif self.current_model_name == "Gemma3-4B":
return self._inference_gemma3_model(image)
else:
raise ValueError(f"Model {self.current_model_name} is not supported or not available.")
def get_current_model(self):
"""
Get the currently loaded model and processor.
Returns:
tuple: A tuple containing (processor, model, model_name).
"""
return self.processor, self.model, self.current_model_name
def cleanup_memory(self):
"""
Explicit memory cleanup method that can be called to free GPU memory.
"""
if self.model is not None:
del self.model
self.model = None
if hasattr(self, 'processor') and self.processor is not None:
del self.processor
self.processor = None
if hasattr(self, 'tokenizer') and self.tokenizer is not None:
del self.tokenizer
self.tokenizer = None
self.current_model_name = None
# Force cleanup
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
print("Memory cleanup completed.")
#########################################################
## Load functions
def _load_smolvlm_model(self, model_name):
"""Load SmolVLM model."""
processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForVision2Seq.from_pretrained(
            model_name,
            _attn_implementation="eager"  # eager attention avoids a flash-attention install requirement
        ).to(self.device)
model.eval()
return processor, model
def _load_qwen25_model(self, model_name):
"""Load Qwen2.5-VL model."""
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype="auto", device_map="auto"
)
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
# "Qwen/Qwen2.5-VL-7B-Instruct",
# torch_dtype=torch.bfloat16,
# attn_implementation="flash_attention_2",
# device_map="auto",
# )
processor = AutoProcessor.from_pretrained(model_name)
model.eval()
return processor, model
def _load_internvl35_model(self, model_name):
"""Load InternVL3.5 model."""
# Load tokenizer (InternVL uses tokenizer instead of processor for text)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Load the model using AutoModel
model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
low_cpu_mem_usage=True,
            use_flash_attn=False,  # set False to avoid CUDA / flash-attn version mismatches
trust_remote_code=True,
device_map="auto"
)
model.eval()
        # Return the tokenizer in place of a processor; InternVL's chat() API consumes the tokenizer directly
return tokenizer, model
def _load_gemma3_model(self, model_name):
"""Load Gemma3 model."""
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
model = Gemma3ForConditionalGeneration.from_pretrained(
model_name,
device_map="auto",
quantization_config=quantization_config
)
processor = AutoProcessor.from_pretrained(model_name)
model.eval()
return processor, model
#########################################################
## Inference functions
def check_processor_and_model(self):
if self.processor is None or self.model is None:
raise ValueError("Processor and model must be loaded before generating a caption.")
def _inference_qwen25_model(self, image):
"""Inference Qwen2.5-VL model."""
self.check_processor_and_model()
messages = [
{
"role": "system",
"content": [{"type": "text", "text": self.system_message}]
},
{
"role": "user",
"content": [
{
"type": "image",
"image": Image.fromarray(image),
},
{"type": "text", "text": self.user_prompt},
],
}
]
# Preparation for inference
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)
# Inference: Generation of the output
generated_ids = self.model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
caption = self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# Clean up tensors to free GPU memory
del inputs, generated_ids, generated_ids_trimmed
if torch.cuda.is_available():
torch.cuda.empty_cache()
return caption
def _inference_gemma3_model(self, image):
"""Inference Gemma3 model."""
self.check_processor_and_model()
messages = [
{
"role": "system",
"content": [{"type": "text", "text": self.system_message}]
},
{
"role": "user",
"content": [
{"type": "image", "image": Image.fromarray(image)},
{"type": "text", "text": self.user_prompt}
]
}
]
inputs = self.processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True,
return_dict=True, return_tensors="pt"
).to(self.model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
caption = self.processor.decode(generation, skip_special_tokens=True)
# Clean up tensors to free GPU memory
del inputs, generation
if torch.cuda.is_available():
torch.cuda.empty_cache()
return caption
def _inference_smolvlm_model(self, image):
self.check_processor_and_model()
messages = [
{
"role": "system",
"content": self.system_message
},
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": self.user_prompt}
]
}
]
# Prepare inputs
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
inputs = inputs.to(self.model.device)
# Generate outputs
gen_kwargs = {
"max_new_tokens": 200, # plenty for ~40 words
# "early_stopping": True, # stop at first EOS
# "no_repeat_ngram_size": 3, # discourage loops
# "length_penalty": 0.8, # slightly favor brevity
# "eos_token_id": processor.tokenizer.eos_token_id,
# "pad_token_id": processor.tokenizer.eos_token_id,
}
generated_ids = self.model.generate(**inputs, **gen_kwargs) # max_new_tokens=500)
generated_texts = self.processor.batch_decode(
generated_ids,
skip_special_tokens=True,
)[0]
# Extract only what the assistant said
if "Assistant:" in generated_texts:
caption = generated_texts.split("Assistant:", 1)[1].strip()
else:
caption = generated_texts.strip()
# Clean up tensors to free GPU memory
del inputs, generated_ids
if torch.cuda.is_available():
torch.cuda.empty_cache()
return caption
def _inference_internvl35_model(self, image):
if self.tokenizer is None:
raise ValueError("Tokenizer must be loaded before generating a caption for InternVL3.5.")
# image can be numpy (H,W,3) or PIL.Image
if hasattr(image, "shape"): # numpy array
pil_image = Image.fromarray(image.astype("uint8"), mode="RGB")
else:
pil_image = image
pixel_values = self._image_to_pixel_values(pil_image, size=448, max_num=12)
pixel_values = pixel_values.to(dtype=torch.bfloat16, device=self.model.device)
# Format question with image token (matches official docs)
question = "<image>\n" + self.user_prompt
# Generation config matching official examples
        gen_cfg = dict(
            max_new_tokens=128,
            do_sample=False,  # greedy decoding; sampling knobs like temperature are ignored
            # Optional sampling parameters from the model card:
            # top_p=0.9,
            # repetition_penalty=1.1
        )
# Use model's chat method (official approach)
response = self.model.chat(self.tokenizer, pixel_values, question, gen_cfg)
# Clean up tensors to free GPU memory
del pixel_values
if torch.cuda.is_available():
torch.cuda.empty_cache()
return response.strip()
def _image_to_pixel_values(self, img, size=448, max_num=12):
transform = self._build_transform(size)
tiles = self._dynamic_preprocess(img, image_size=size, max_num=max_num, use_thumbnail=True)
pixel_values = torch.stack([transform(t) for t in tiles])
return pixel_values
def _dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
# same logic as the model card: split into tiles based on aspect ratio
w, h = image.size
aspect = w / h
targets = sorted({(i, j) for n in range(min_num, max_num+1)
for i in range(1, n+1) for j in range(1, n+1)
if i*j <= max_num and i*j >= min_num},
key=lambda x: x[0]*x[1])
# pick closest ratio
best = min(targets, key=lambda r: abs(aspect - r[0]/r[1]))
tw, th = image_size * best[0], image_size * best[1]
resized = image.resize((tw, th))
tiles = []
for i in range(best[0] * best[1]):
box = ((i % (tw // image_size)) * image_size,
(i // (tw // image_size)) * image_size,
((i % (tw // image_size)) + 1) * image_size,
((i // (tw // image_size)) + 1) * image_size)
tiles.append(resized.crop(box))
if use_thumbnail and len(tiles) != 1:
tiles.append(image.resize((image_size, image_size)))
return tiles
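    # Worked example for the tiling above (illustrative assumption: an 800x600
    # input). Its aspect ratio 4/3 matches the (4, 3) grid exactly, so the image
    # is resized to 1792x1344 and cropped into 12 tiles of 448x448; with
    # use_thumbnail=True a 13th 448x448 thumbnail of the whole image is
    # appended, so _image_to_pixel_values returns a (13, 3, 448, 448) tensor.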
def _build_transform(self, input_size=448):
return T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Global VLM Manager instance
vlm_manager = VLMManager()
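# Minimal usage sketch (an assumption for illustration, not part of the app:
# the synthetic gray frame stands in for a real photo, and the demo is guarded
# because importing this module already loads the default Gemma3-4B model).
if __name__ == "__main__":
    import numpy as np
    demo_image = np.full((448, 448, 3), 128, dtype=np.uint8)  # plain gray RGB frame
    print(vlm_manager.generate_caption(demo_image))
    # Switch models on the fly; the previous one is freed and the CUDA cache cleared.
    vlm_manager.load_model("SmolVLM-500M")
    print(vlm_manager.generate_caption(demo_image))
    vlm_manager.cleanup_memory()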