# 🖼️ Mô hình dự đoán chữ số viết tay ## 📝 Mô tả Đây là mô hình Vision Transformer (ViT‑Base với patch size 32) được fine-tuned từ openai/clip-vit-base-patch32 để thực hiện phân loại chữ số viết tay (MNIST). Chỉ phần vision encoder được training lại, giữ nguyên text encoder để giữ khả năng zero-shot của CLIP. ## 📌 Nhiệm vụ Dự đoán chữ số (0–9) từ ảnh MNIST, dưới dạng phân loại đơn giản gồm 10 lớp. ## 📥 Đầu vào Ảnh xám (grayscale) kích thước 28×28, mô hình sẽ tự xử lý chuẩn hóa/chuyển sang 3 kênh nếu cần (qua processor của CLIP). Đầu vào sẽ được đưa vào dưới dạng tensor [batch_size, 3, 224, 224] sau khi qua CLIPProcessor. ## 📤 Đầu ra logits có kích thước [batch_size, 10], đại diện xác suất tương ứng với mỗi chữ số từ 0 đến 9. ## 🧪 Kết quả đánh giá Giai đoạn Accuracy Pre-trained (chưa fine-tune) 47.6% Sau fine-tune 99.57% ## 🛠 Yêu cầu thư viện Cài đặt các thư viện cần thiết: ```bash pip install torch transformers datasets pillow ``` ## 🚀 Cách sử dụng ### 🎯 Sử dụng encoder đã fine-tuned ```python import torch from transformers import CLIPVisionModel, CLIPProcessor from PIL import Image # Tải vision encoder và CLIP processor vision_model = CLIPVisionModel.from_pretrained("zhaospei/Model_11") processor = CLIPProcessor.from_pretrained("zhaospei/Model_10") # Chuẩn bị ảnh MNIST (28×28) img = Image.open("path_to_mnist_digit.png").convert("L") # ảnh xám img = img.resize((224, 224)).convert("RGB") # mở rộng thành RGB 3 kênh inputs = processor(images=img, return_tensors="pt") # Lấy embedding từ ảnh with torch.no_grad(): vision_outputs = vision_model(**inputs) image_embeds = vision_outputs.last_hidden_state[:, 0, :] # CLS token embedding print("Image embedding shape:", image_embeds.shape) ``` ### 🔄 Kết hợp với CLIP để trên nền zero-shot ```python from transformers import CLIPModel # Tải CLIP đầy đủ clip = CLIPModel.from_pretrained("zhaospei/Model_11") # Thay thành encoder đã fine-tune clip.vision_model.load_state_dict(vision_model.vision_model.state_dict()) # Ví dụ zero-shot MNIST from PIL import Image img = Image.open("path_to_mnist_digit.png").convert("L").resize((224, 224)).convert("RGB") texts = [str(i) for i in range(10)] inputs = processor(text=texts, images=img, return_tensors="pt", padding=True) with torch.no_grad(): outputs = clip(**inputs) probs = outputs.logits_per_image.softmax(dim=1)[0] print({texts[i]: float(probs[i]) for i in range(10)}) ``` ## ⚙️ Thông tin huấn luyện Optimizer: Adam, learning rate = 1e-5 Batch size: 32 Số bước huấn luyện: 4000 Chỉ fine-tune vision encode