anhtung1106's picture
Upload 4 files
6366e63 verified
"""
=============================================================================
utils.py - Module quản lý các mô hình Image Captioning
=============================================================================
File này chứa class ImageCaptioningModels để quản lý việc tải và sử dụng
3 mô hình sinh mô tả ảnh:
1. ViT-GPT2: Vision Transformer + GPT-2 (nlpconnect/vit-gpt2-image-captioning)
2. BLIP-Large: Salesforce BLIP (Salesforce/blip-image-captioning-large)
3. GIT: Microsoft Generative Image-to-text (microsoft/git-large-coco)
Các tính năng chính:
- Tải mô hình theo yêu cầu (lazy loading) để tiết kiệm bộ nhớ
- Hỗ trợ GPU (CUDA) nếu có sẵn
- Sinh mô tả với các tham số có thể điều chỉnh (temperature, top_k, top_p, ...)
- Giải phóng bộ nhớ khi không cần thiết
Tác giả: Đồ án 2 - 2024
=============================================================================
"""
# =============================================================================
# IMPORT CÁC THƯ VIỆN CẦN THIẾT
# =============================================================================
import torch # Thư viện deep learning chính
from PIL import Image # Xử lý hình ảnh
# Import các class từ thư viện Transformers của Hugging Face
from transformers import (
# Cho mô hình ViT-GPT2
VisionEncoderDecoderModel, # Kiến trúc encoder-decoder kết hợp vision và text
ViTImageProcessor, # Tiền xử lý ảnh cho Vision Transformer
AutoTokenizer, # Tokenizer tự động cho text
# Cho mô hình BLIP
BlipProcessor, # Tiền xử lý cho BLIP (cả ảnh và text)
BlipForConditionalGeneration, # Mô hình BLIP cho sinh caption
# Cho mô hình GIT
AutoProcessor, # Processor tự động
AutoModelForCausalLM # Mô hình ngôn ngữ tự động hồi quy
)
import warnings
warnings.filterwarnings('ignore') # Tắt các cảnh báo không cần thiết
# =============================================================================
# CLASS IMAGEСAPTIONINGMODELS - QUẢN LÝ CÁC MÔ HÌNH SINH MÔ TẢ ẢNH
# =============================================================================
class ImageCaptioningModels:
"""
Class quản lý việc tải và sử dụng các mô hình sinh mô tả ảnh.
Attributes:
models (dict): Dictionary lưu trữ các mô hình đã tải
processors (dict): Dictionary lưu trữ các processor/tokenizer tương ứng
device (torch.device): Thiết bị chạy mô hình (CPU hoặc GPU)
Ví dụ sử dụng:
>>> manager = ImageCaptioningModels()
>>> manager.load_vit_gpt2() # Tải mô hình ViT-GPT2
>>> caption = manager.predict_vit_gpt2(image) # Sinh caption
"""
def __init__(self):
"""
Khởi tạo ImageCaptioningModels.
- Tạo dictionary rỗng để lưu models và processors
- Tự động phát hiện và sử dụng GPU nếu có (CUDA)
- In ra thiết bị đang sử dụng
"""
self.models = {} # Lưu trữ các mô hình đã tải {tên: model}
self.processors = {} # Lưu trữ processor tương ứng {tên: processor}
# Kiểm tra GPU có sẵn không, nếu có thì dùng GPU, không thì dùng CPU
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {self.device}")
# =========================================================================
# PHẦN 1: CÁC HÀM TẢI MÔ HÌNH (LOAD MODELS)
# =========================================================================
def load_vit_gpt2(self):
"""
Tải mô hình ViT-GPT2 từ Hugging Face Hub.
Mô hình ViT-GPT2 kết hợp:
- Vision Transformer (ViT): Mã hóa ảnh thành vector đặc trưng
- GPT-2: Sinh văn bản mô tả từ vector đặc trưng
Returns:
bool: True nếu tải thành công, False nếu có lỗi
Lưu ý:
- Chỉ tải nếu chưa tải trước đó (kiểm tra trong self.models)
- Mô hình được chuyển sang GPU nếu có sẵn
"""
try:
# Kiểm tra xem model đã được tải chưa để tránh tải lại
if 'vit_gpt2' not in self.models:
model_name = "nlpconnect/vit-gpt2-image-captioning"
# Tải mô hình VisionEncoderDecoder và chuyển sang device (GPU/CPU)
model = VisionEncoderDecoderModel.from_pretrained(model_name).to(self.device)
# Tải feature extractor (tiền xử lý ảnh) và tokenizer (xử lý text)
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Lưu vào dictionary để sử dụng sau
self.models['vit_gpt2'] = model
self.processors['vit_gpt2'] = (feature_extractor, tokenizer) # Tuple gồm 2 thành phần
print("ViT-GPT2 model loaded successfully")
return True
except Exception as e:
print(f"Error loading ViT-GPT2: {e}")
return False
def load_blip_large(self):
"""
Tải mô hình BLIP-Large từ Salesforce.
BLIP (Bootstrapping Language-Image Pre-training) là mô hình
multimodal được huấn luyện trên dữ liệu lớn với kỹ thuật
bootstrapping để cải thiện chất lượng caption.
Returns:
bool: True nếu tải thành công, False nếu có lỗi
"""
try:
if 'blip_large' not in self.models:
model_name = "Salesforce/blip-image-captioning-large"
# BLIP chỉ cần 1 processor (xử lý cả ảnh và text)
processor = BlipProcessor.from_pretrained(model_name)
model = BlipForConditionalGeneration.from_pretrained(model_name).to(self.device)
self.models['blip_large'] = model
self.processors['blip_large'] = processor
print("BLIP-Large model loaded successfully")
return True
except Exception as e:
print(f"Error loading BLIP-Large: {e}")
return False
def load_git(self):
"""
Tải mô hình Microsoft GIT (Generative Image-to-text Transformer).
GIT là mô hình của Microsoft với kiến trúc đơn giản nhưng hiệu quả,
được huấn luyện trên tập COCO dataset.
Returns:
bool: True nếu tải thành công, False nếu có lỗi
"""
try:
if 'git' not in self.models:
model_name = "microsoft/git-large-coco"
# Sử dụng AutoProcessor và AutoModelForCausalLM
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
self.models['git'] = model
self.processors['git'] = processor
print("Microsoft GIT model loaded successfully")
return True
except Exception as e:
print(f"Error loading Microsoft GIT: {e}")
return False
# =========================================================================
# PHẦN 2: CÁC HÀM SINH MÔ TẢ (PREDICTION FUNCTIONS)
# =========================================================================
def predict_vit_gpt2(self, image, max_length=50, num_beams=4, temperature=1.0,
top_k=50, top_p=1.0, repetition_penalty=1.0, do_sample=False):
"""
Sinh mô tả ảnh sử dụng mô hình ViT-GPT2.
Args:
image (PIL.Image): Ảnh đầu vào cần sinh mô tả
max_length (int): Độ dài tối đa của caption (số token)
num_beams (int): Số beam cho beam search (chỉ dùng khi do_sample=False)
temperature (float): Độ ngẫu nhiên (cao = sáng tạo hơn)
top_k (int): Số từ được xem xét ở mỗi bước sinh
top_p (float): Ngưỡng nucleus sampling (0.0-1.0)
repetition_penalty (float): Hệ số phạt khi lặp từ (>1.0 = phạt mạnh hơn)
do_sample (bool): True = sampling ngẫu nhiên, False = beam search
Returns:
str: Mô tả được sinh ra (đã capitalize)
Giải thích các tham số:
- temperature: Điều khiển độ "sáng tạo".
Thấp (0.1-0.5) = chắc chắn, cao (1.0-2.0) = đa dạng
- top_k: Chỉ xét K từ có xác suất cao nhất
- top_p: Chỉ xét các từ cho đến khi tổng xác suất đạt p
- num_beams: Số nhánh trong beam search (cao = chất lượng tốt hơn, chậm hơn)
"""
# Kiểm tra model đã tải chưa
if 'vit_gpt2' not in self.models:
return "Model not loaded"
model = self.models['vit_gpt2']
feature_extractor, tokenizer = self.processors['vit_gpt2']
# Chuyển ảnh sang RGB nếu cần (một số ảnh có thể là RGBA hoặc grayscale)
if image.mode != "RGB":
image = image.convert("RGB")
# Tiền xử lý ảnh: resize, normalize, chuyển thành tensor
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(self.device)
# Sinh caption với gradient tắt (inference mode)
with torch.no_grad():
# Cấu hình các tham số sinh văn bản
gen_kwargs = {
"max_length": max_length,
"early_stopping": True, # Dừng sớm khi gặp token kết thúc
"repetition_penalty": repetition_penalty,
}
# Nếu dùng sampling (ngẫu nhiên)
if do_sample:
gen_kwargs.update({
"do_sample": True,
"temperature": temperature,
"top_k": top_k if top_k > 0 else None, # 0 = không giới hạn
"top_p": top_p,
"num_beams": 1, # Sampling không dùng beam search
})
# Nếu dùng beam search (deterministic)
else:
gen_kwargs.update({
"do_sample": False,
"num_beams": num_beams,
})
# Gọi hàm generate để sinh caption
output_ids = model.generate(pixel_values, **gen_kwargs)
# Giải mã token IDs thành văn bản, bỏ các token đặc biệt
caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
return caption.capitalize() # Viết hoa chữ cái đầu
def predict_blip_large(self, image, max_length=50, num_beams=5, temperature=0.7,
top_k=50, top_p=0.9, repetition_penalty=1.0, do_sample=True):
"""
Sinh mô tả ảnh sử dụng mô hình BLIP-Large.
Tương tự predict_vit_gpt2 nhưng sử dụng kiến trúc BLIP.
BLIP thường cho kết quả chính xác hơn nhờ được huấn luyện
trên dữ liệu chất lượng cao.
Args:
(Tương tự predict_vit_gpt2)
Returns:
str: Mô tả được sinh ra
"""
if 'blip_large' not in self.models:
return "Model not loaded"
model = self.models['blip_large']
processor = self.processors['blip_large']
# Chuyển sang RGB nếu cần
if image.mode != "RGB":
image = image.convert("RGB")
# Tiền xử lý ảnh với BLIP processor
inputs = processor(image, return_tensors="pt").to(self.device)
with torch.no_grad():
gen_kwargs = {
"max_length": max_length,
"repetition_penalty": repetition_penalty,
}
if do_sample:
gen_kwargs.update({
"do_sample": True,
"temperature": temperature,
"top_k": top_k if top_k > 0 else None,
"top_p": top_p,
"num_beams": 1,
})
else:
gen_kwargs.update({
"do_sample": False,
"num_beams": num_beams,
})
# BLIP sử dụng **inputs thay vì pixel_values riêng
output_ids = model.generate(**inputs, **gen_kwargs)
# Giải mã kết quả
caption = processor.decode(output_ids[0], skip_special_tokens=True)
return caption.capitalize()
def predict_git(self, image, max_length=50, num_beams=5, temperature=0.7,
top_k=50, top_p=0.9, repetition_penalty=1.0, do_sample=True):
"""
Sinh mô tả ảnh sử dụng mô hình Microsoft GIT.
GIT (Generative Image-to-text Transformer) có kiến trúc
đơn giản với một transformer duy nhất xử lý cả ảnh và text.
Args:
(Tương tự predict_vit_gpt2)
Returns:
str: Mô tả được sinh ra
"""
if 'git' not in self.models:
return "Model not loaded"
model = self.models['git']
processor = self.processors['git']
if image.mode != "RGB":
image = image.convert("RGB")
# Tiền xử lý ảnh
inputs = processor(images=image, return_tensors="pt").to(self.device)
with torch.no_grad():
gen_kwargs = {
"max_length": max_length,
"repetition_penalty": repetition_penalty,
}
if do_sample:
gen_kwargs.update({
"do_sample": True,
"temperature": temperature,
"top_k": top_k if top_k > 0 else None,
"top_p": top_p,
"num_beams": 1,
})
else:
gen_kwargs.update({
"do_sample": False,
"num_beams": num_beams,
})
# GIT sử dụng pixel_values làm input
output_ids = model.generate(pixel_values=inputs.pixel_values, **gen_kwargs)
# GIT sử dụng batch_decode và lấy phần tử đầu tiên
caption = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
return caption.capitalize()
# =========================================================================
# PHẦN 3: CÁC HÀM TIỆN ÍCH (UTILITY FUNCTIONS)
# =========================================================================
def predict(self, model_name, image, **kwargs):
"""
Hàm dự đoán thống nhất - gọi model tương ứng dựa trên tên.
Giúp đơn giản hóa việc gọi các model khác nhau với cùng interface.
Args:
model_name (str): Tên model ("ViT-GPT2", "BLIP-Large", hoặc "GIT")
image (PIL.Image): Ảnh đầu vào
**kwargs: Các tham số sinh caption khác
Returns:
str: Mô tả được sinh ra
Ví dụ:
>>> caption = manager.predict("BLIP-Large", image, max_length=60)
"""
if model_name == "ViT-GPT2":
return self.predict_vit_gpt2(image, **kwargs)
elif model_name == "BLIP-Large":
return self.predict_blip_large(image, **kwargs)
elif model_name == "GIT":
return self.predict_git(image, **kwargs)
else:
return f"Model {model_name} not supported"
def unload_model(self, model_name):
"""
Giải phóng một model khỏi bộ nhớ.
Hữu ích khi cần tiết kiệm RAM/VRAM, đặc biệt trên các máy
có tài nguyên hạn chế.
Args:
model_name (str): Tên model cần giải phóng
Returns:
bool: True nếu giải phóng thành công, False nếu model không tồn tại
Ví dụ:
>>> manager.unload_model("ViT-GPT2") # Giải phóng ViT-GPT2
"""
# Chuyển tên model về dạng key (viết thường, thay - bằng _)
model_key = model_name.lower().replace("-", "_")
if model_key in self.models:
# Xóa model và processor khỏi dictionary
del self.models[model_key]
del self.processors[model_key]
# Giải phóng bộ nhớ GPU nếu đang dùng
torch.cuda.empty_cache()
print(f"{model_name} unloaded")
return True
return False