Spaces:
Sleeping
Sleeping
| """ | |
| ============================================================================= | |
| utils.py - Module quản lý các mô hình Image Captioning | |
| ============================================================================= | |
| File này chứa class ImageCaptioningModels để quản lý việc tải và sử dụng | |
| 3 mô hình sinh mô tả ảnh: | |
| 1. ViT-GPT2: Vision Transformer + GPT-2 (nlpconnect/vit-gpt2-image-captioning) | |
| 2. BLIP-Large: Salesforce BLIP (Salesforce/blip-image-captioning-large) | |
| 3. GIT: Microsoft Generative Image-to-text (microsoft/git-large-coco) | |
| Các tính năng chính: | |
| - Tải mô hình theo yêu cầu (lazy loading) để tiết kiệm bộ nhớ | |
| - Hỗ trợ GPU (CUDA) nếu có sẵn | |
| - Sinh mô tả với các tham số có thể điều chỉnh (temperature, top_k, top_p, ...) | |
| - Giải phóng bộ nhớ khi không cần thiết | |
| Tác giả: Đồ án 2 - 2024 | |
| ============================================================================= | |
| """ | |
| # ============================================================================= | |
| # IMPORT CÁC THƯ VIỆN CẦN THIẾT | |
| # ============================================================================= | |
| import torch # Thư viện deep learning chính | |
| from PIL import Image # Xử lý hình ảnh | |
| # Import các class từ thư viện Transformers của Hugging Face | |
| from transformers import ( | |
| # Cho mô hình ViT-GPT2 | |
| VisionEncoderDecoderModel, # Kiến trúc encoder-decoder kết hợp vision và text | |
| ViTImageProcessor, # Tiền xử lý ảnh cho Vision Transformer | |
| AutoTokenizer, # Tokenizer tự động cho text | |
| # Cho mô hình BLIP | |
| BlipProcessor, # Tiền xử lý cho BLIP (cả ảnh và text) | |
| BlipForConditionalGeneration, # Mô hình BLIP cho sinh caption | |
| # Cho mô hình GIT | |
| AutoProcessor, # Processor tự động | |
| AutoModelForCausalLM # Mô hình ngôn ngữ tự động hồi quy | |
| ) | |
| import warnings | |
| warnings.filterwarnings('ignore') # Tắt các cảnh báo không cần thiết | |
| # ============================================================================= | |
| # CLASS IMAGEСAPTIONINGMODELS - QUẢN LÝ CÁC MÔ HÌNH SINH MÔ TẢ ẢNH | |
| # ============================================================================= | |
| class ImageCaptioningModels: | |
| """ | |
| Class quản lý việc tải và sử dụng các mô hình sinh mô tả ảnh. | |
| Attributes: | |
| models (dict): Dictionary lưu trữ các mô hình đã tải | |
| processors (dict): Dictionary lưu trữ các processor/tokenizer tương ứng | |
| device (torch.device): Thiết bị chạy mô hình (CPU hoặc GPU) | |
| Ví dụ sử dụng: | |
| >>> manager = ImageCaptioningModels() | |
| >>> manager.load_vit_gpt2() # Tải mô hình ViT-GPT2 | |
| >>> caption = manager.predict_vit_gpt2(image) # Sinh caption | |
| """ | |
| def __init__(self): | |
| """ | |
| Khởi tạo ImageCaptioningModels. | |
| - Tạo dictionary rỗng để lưu models và processors | |
| - Tự động phát hiện và sử dụng GPU nếu có (CUDA) | |
| - In ra thiết bị đang sử dụng | |
| """ | |
| self.models = {} # Lưu trữ các mô hình đã tải {tên: model} | |
| self.processors = {} # Lưu trữ processor tương ứng {tên: processor} | |
| # Kiểm tra GPU có sẵn không, nếu có thì dùng GPU, không thì dùng CPU | |
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"Using device: {self.device}") | |
| # ========================================================================= | |
| # PHẦN 1: CÁC HÀM TẢI MÔ HÌNH (LOAD MODELS) | |
| # ========================================================================= | |
| def load_vit_gpt2(self): | |
| """ | |
| Tải mô hình ViT-GPT2 từ Hugging Face Hub. | |
| Mô hình ViT-GPT2 kết hợp: | |
| - Vision Transformer (ViT): Mã hóa ảnh thành vector đặc trưng | |
| - GPT-2: Sinh văn bản mô tả từ vector đặc trưng | |
| Returns: | |
| bool: True nếu tải thành công, False nếu có lỗi | |
| Lưu ý: | |
| - Chỉ tải nếu chưa tải trước đó (kiểm tra trong self.models) | |
| - Mô hình được chuyển sang GPU nếu có sẵn | |
| """ | |
| try: | |
| # Kiểm tra xem model đã được tải chưa để tránh tải lại | |
| if 'vit_gpt2' not in self.models: | |
| model_name = "nlpconnect/vit-gpt2-image-captioning" | |
| # Tải mô hình VisionEncoderDecoder và chuyển sang device (GPU/CPU) | |
| model = VisionEncoderDecoderModel.from_pretrained(model_name).to(self.device) | |
| # Tải feature extractor (tiền xử lý ảnh) và tokenizer (xử lý text) | |
| feature_extractor = ViTImageProcessor.from_pretrained(model_name) | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| # Lưu vào dictionary để sử dụng sau | |
| self.models['vit_gpt2'] = model | |
| self.processors['vit_gpt2'] = (feature_extractor, tokenizer) # Tuple gồm 2 thành phần | |
| print("ViT-GPT2 model loaded successfully") | |
| return True | |
| except Exception as e: | |
| print(f"Error loading ViT-GPT2: {e}") | |
| return False | |
| def load_blip_large(self): | |
| """ | |
| Tải mô hình BLIP-Large từ Salesforce. | |
| BLIP (Bootstrapping Language-Image Pre-training) là mô hình | |
| multimodal được huấn luyện trên dữ liệu lớn với kỹ thuật | |
| bootstrapping để cải thiện chất lượng caption. | |
| Returns: | |
| bool: True nếu tải thành công, False nếu có lỗi | |
| """ | |
| try: | |
| if 'blip_large' not in self.models: | |
| model_name = "Salesforce/blip-image-captioning-large" | |
| # BLIP chỉ cần 1 processor (xử lý cả ảnh và text) | |
| processor = BlipProcessor.from_pretrained(model_name) | |
| model = BlipForConditionalGeneration.from_pretrained(model_name).to(self.device) | |
| self.models['blip_large'] = model | |
| self.processors['blip_large'] = processor | |
| print("BLIP-Large model loaded successfully") | |
| return True | |
| except Exception as e: | |
| print(f"Error loading BLIP-Large: {e}") | |
| return False | |
| def load_git(self): | |
| """ | |
| Tải mô hình Microsoft GIT (Generative Image-to-text Transformer). | |
| GIT là mô hình của Microsoft với kiến trúc đơn giản nhưng hiệu quả, | |
| được huấn luyện trên tập COCO dataset. | |
| Returns: | |
| bool: True nếu tải thành công, False nếu có lỗi | |
| """ | |
| try: | |
| if 'git' not in self.models: | |
| model_name = "microsoft/git-large-coco" | |
| # Sử dụng AutoProcessor và AutoModelForCausalLM | |
| processor = AutoProcessor.from_pretrained(model_name) | |
| model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device) | |
| self.models['git'] = model | |
| self.processors['git'] = processor | |
| print("Microsoft GIT model loaded successfully") | |
| return True | |
| except Exception as e: | |
| print(f"Error loading Microsoft GIT: {e}") | |
| return False | |
| # ========================================================================= | |
| # PHẦN 2: CÁC HÀM SINH MÔ TẢ (PREDICTION FUNCTIONS) | |
| # ========================================================================= | |
| def predict_vit_gpt2(self, image, max_length=50, num_beams=4, temperature=1.0, | |
| top_k=50, top_p=1.0, repetition_penalty=1.0, do_sample=False): | |
| """ | |
| Sinh mô tả ảnh sử dụng mô hình ViT-GPT2. | |
| Args: | |
| image (PIL.Image): Ảnh đầu vào cần sinh mô tả | |
| max_length (int): Độ dài tối đa của caption (số token) | |
| num_beams (int): Số beam cho beam search (chỉ dùng khi do_sample=False) | |
| temperature (float): Độ ngẫu nhiên (cao = sáng tạo hơn) | |
| top_k (int): Số từ được xem xét ở mỗi bước sinh | |
| top_p (float): Ngưỡng nucleus sampling (0.0-1.0) | |
| repetition_penalty (float): Hệ số phạt khi lặp từ (>1.0 = phạt mạnh hơn) | |
| do_sample (bool): True = sampling ngẫu nhiên, False = beam search | |
| Returns: | |
| str: Mô tả được sinh ra (đã capitalize) | |
| Giải thích các tham số: | |
| - temperature: Điều khiển độ "sáng tạo". | |
| Thấp (0.1-0.5) = chắc chắn, cao (1.0-2.0) = đa dạng | |
| - top_k: Chỉ xét K từ có xác suất cao nhất | |
| - top_p: Chỉ xét các từ cho đến khi tổng xác suất đạt p | |
| - num_beams: Số nhánh trong beam search (cao = chất lượng tốt hơn, chậm hơn) | |
| """ | |
| # Kiểm tra model đã tải chưa | |
| if 'vit_gpt2' not in self.models: | |
| return "Model not loaded" | |
| model = self.models['vit_gpt2'] | |
| feature_extractor, tokenizer = self.processors['vit_gpt2'] | |
| # Chuyển ảnh sang RGB nếu cần (một số ảnh có thể là RGBA hoặc grayscale) | |
| if image.mode != "RGB": | |
| image = image.convert("RGB") | |
| # Tiền xử lý ảnh: resize, normalize, chuyển thành tensor | |
| pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(self.device) | |
| # Sinh caption với gradient tắt (inference mode) | |
| with torch.no_grad(): | |
| # Cấu hình các tham số sinh văn bản | |
| gen_kwargs = { | |
| "max_length": max_length, | |
| "early_stopping": True, # Dừng sớm khi gặp token kết thúc | |
| "repetition_penalty": repetition_penalty, | |
| } | |
| # Nếu dùng sampling (ngẫu nhiên) | |
| if do_sample: | |
| gen_kwargs.update({ | |
| "do_sample": True, | |
| "temperature": temperature, | |
| "top_k": top_k if top_k > 0 else None, # 0 = không giới hạn | |
| "top_p": top_p, | |
| "num_beams": 1, # Sampling không dùng beam search | |
| }) | |
| # Nếu dùng beam search (deterministic) | |
| else: | |
| gen_kwargs.update({ | |
| "do_sample": False, | |
| "num_beams": num_beams, | |
| }) | |
| # Gọi hàm generate để sinh caption | |
| output_ids = model.generate(pixel_values, **gen_kwargs) | |
| # Giải mã token IDs thành văn bản, bỏ các token đặc biệt | |
| caption = tokenizer.decode(output_ids[0], skip_special_tokens=True) | |
| return caption.capitalize() # Viết hoa chữ cái đầu | |
| def predict_blip_large(self, image, max_length=50, num_beams=5, temperature=0.7, | |
| top_k=50, top_p=0.9, repetition_penalty=1.0, do_sample=True): | |
| """ | |
| Sinh mô tả ảnh sử dụng mô hình BLIP-Large. | |
| Tương tự predict_vit_gpt2 nhưng sử dụng kiến trúc BLIP. | |
| BLIP thường cho kết quả chính xác hơn nhờ được huấn luyện | |
| trên dữ liệu chất lượng cao. | |
| Args: | |
| (Tương tự predict_vit_gpt2) | |
| Returns: | |
| str: Mô tả được sinh ra | |
| """ | |
| if 'blip_large' not in self.models: | |
| return "Model not loaded" | |
| model = self.models['blip_large'] | |
| processor = self.processors['blip_large'] | |
| # Chuyển sang RGB nếu cần | |
| if image.mode != "RGB": | |
| image = image.convert("RGB") | |
| # Tiền xử lý ảnh với BLIP processor | |
| inputs = processor(image, return_tensors="pt").to(self.device) | |
| with torch.no_grad(): | |
| gen_kwargs = { | |
| "max_length": max_length, | |
| "repetition_penalty": repetition_penalty, | |
| } | |
| if do_sample: | |
| gen_kwargs.update({ | |
| "do_sample": True, | |
| "temperature": temperature, | |
| "top_k": top_k if top_k > 0 else None, | |
| "top_p": top_p, | |
| "num_beams": 1, | |
| }) | |
| else: | |
| gen_kwargs.update({ | |
| "do_sample": False, | |
| "num_beams": num_beams, | |
| }) | |
| # BLIP sử dụng **inputs thay vì pixel_values riêng | |
| output_ids = model.generate(**inputs, **gen_kwargs) | |
| # Giải mã kết quả | |
| caption = processor.decode(output_ids[0], skip_special_tokens=True) | |
| return caption.capitalize() | |
| def predict_git(self, image, max_length=50, num_beams=5, temperature=0.7, | |
| top_k=50, top_p=0.9, repetition_penalty=1.0, do_sample=True): | |
| """ | |
| Sinh mô tả ảnh sử dụng mô hình Microsoft GIT. | |
| GIT (Generative Image-to-text Transformer) có kiến trúc | |
| đơn giản với một transformer duy nhất xử lý cả ảnh và text. | |
| Args: | |
| (Tương tự predict_vit_gpt2) | |
| Returns: | |
| str: Mô tả được sinh ra | |
| """ | |
| if 'git' not in self.models: | |
| return "Model not loaded" | |
| model = self.models['git'] | |
| processor = self.processors['git'] | |
| if image.mode != "RGB": | |
| image = image.convert("RGB") | |
| # Tiền xử lý ảnh | |
| inputs = processor(images=image, return_tensors="pt").to(self.device) | |
| with torch.no_grad(): | |
| gen_kwargs = { | |
| "max_length": max_length, | |
| "repetition_penalty": repetition_penalty, | |
| } | |
| if do_sample: | |
| gen_kwargs.update({ | |
| "do_sample": True, | |
| "temperature": temperature, | |
| "top_k": top_k if top_k > 0 else None, | |
| "top_p": top_p, | |
| "num_beams": 1, | |
| }) | |
| else: | |
| gen_kwargs.update({ | |
| "do_sample": False, | |
| "num_beams": num_beams, | |
| }) | |
| # GIT sử dụng pixel_values làm input | |
| output_ids = model.generate(pixel_values=inputs.pixel_values, **gen_kwargs) | |
| # GIT sử dụng batch_decode và lấy phần tử đầu tiên | |
| caption = processor.batch_decode(output_ids, skip_special_tokens=True)[0] | |
| return caption.capitalize() | |
| # ========================================================================= | |
| # PHẦN 3: CÁC HÀM TIỆN ÍCH (UTILITY FUNCTIONS) | |
| # ========================================================================= | |
| def predict(self, model_name, image, **kwargs): | |
| """ | |
| Hàm dự đoán thống nhất - gọi model tương ứng dựa trên tên. | |
| Giúp đơn giản hóa việc gọi các model khác nhau với cùng interface. | |
| Args: | |
| model_name (str): Tên model ("ViT-GPT2", "BLIP-Large", hoặc "GIT") | |
| image (PIL.Image): Ảnh đầu vào | |
| **kwargs: Các tham số sinh caption khác | |
| Returns: | |
| str: Mô tả được sinh ra | |
| Ví dụ: | |
| >>> caption = manager.predict("BLIP-Large", image, max_length=60) | |
| """ | |
| if model_name == "ViT-GPT2": | |
| return self.predict_vit_gpt2(image, **kwargs) | |
| elif model_name == "BLIP-Large": | |
| return self.predict_blip_large(image, **kwargs) | |
| elif model_name == "GIT": | |
| return self.predict_git(image, **kwargs) | |
| else: | |
| return f"Model {model_name} not supported" | |
| def unload_model(self, model_name): | |
| """ | |
| Giải phóng một model khỏi bộ nhớ. | |
| Hữu ích khi cần tiết kiệm RAM/VRAM, đặc biệt trên các máy | |
| có tài nguyên hạn chế. | |
| Args: | |
| model_name (str): Tên model cần giải phóng | |
| Returns: | |
| bool: True nếu giải phóng thành công, False nếu model không tồn tại | |
| Ví dụ: | |
| >>> manager.unload_model("ViT-GPT2") # Giải phóng ViT-GPT2 | |
| """ | |
| # Chuyển tên model về dạng key (viết thường, thay - bằng _) | |
| model_key = model_name.lower().replace("-", "_") | |
| if model_key in self.models: | |
| # Xóa model và processor khỏi dictionary | |
| del self.models[model_key] | |
| del self.processors[model_key] | |
| # Giải phóng bộ nhớ GPU nếu đang dùng | |
| torch.cuda.empty_cache() | |
| print(f"{model_name} unloaded") | |
| return True | |
| return False |