File size: 2,877 Bytes
6d9fd27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# 🖼️ Mô hình dự đoán chữ số viết tay
## 📝 Mô tả
Đây là mô hình Vision Transformer (ViT‑Base với patch size 32) được fine-tuned từ openai/clip-vit-base-patch32 để thực hiện phân loại chữ số viết tay (MNIST). Chỉ phần vision encoder được training lại, giữ nguyên text encoder để giữ khả năng zero-shot của CLIP.

## 📌 Nhiệm vụ
Dự đoán chữ số (0–9) từ ảnh MNIST, dưới dạng phân loại đơn giản gồm 10 lớp.

## 📥 Đầu vào
Ảnh xám (grayscale) kích thước 28×28, mô hình sẽ tự xử lý chuẩn hóa/chuyển sang 3 kênh nếu cần (qua processor của CLIP).
Đầu vào sẽ được đưa vào dưới dạng tensor [batch_size, 3, 224, 224] sau khi qua CLIPProcessor.

## 📤 Đầu ra
logits có kích thước [batch_size, 10], đại diện xác suất tương ứng với mỗi chữ số từ 0 đến 9.

## 🧪 Kết quả đánh giá
Giai đoạn	Accuracy
Pre-trained (chưa fine-tune)	47.6%
Sau fine-tune	99.57%

## 🛠 Yêu cầu thư viện
Cài đặt các thư viện cần thiết:

```bash
pip install torch transformers datasets pillow
```

## 🚀 Cách sử dụng
### 🎯 Sử dụng encoder đã fine-tuned
```python
import torch
from transformers import CLIPVisionModel, CLIPProcessor
from PIL import Image

# Tải vision encoder và CLIP processor
vision_model = CLIPVisionModel.from_pretrained("zhaospei/Model_11")
processor = CLIPProcessor.from_pretrained("zhaospei/Model_10")

# Chuẩn bị ảnh MNIST (28×28)
img = Image.open("path_to_mnist_digit.png").convert("L")  # ảnh xám
img = img.resize((224, 224)).convert("RGB")  # mở rộng thành RGB 3 kênh

inputs = processor(images=img, return_tensors="pt")

# Lấy embedding từ ảnh
with torch.no_grad():
    vision_outputs = vision_model(**inputs)

image_embeds = vision_outputs.last_hidden_state[:, 0, :]  # CLS token embedding
print("Image embedding shape:", image_embeds.shape)
```

### 🔄 Kết hợp với CLIP để trên nền zero-shot

```python
from transformers import CLIPModel

# Tải CLIP đầy đủ
clip = CLIPModel.from_pretrained("zhaospei/Model_11")
# Thay thành encoder đã fine-tune
clip.vision_model.load_state_dict(vision_model.vision_model.state_dict())

# Ví dụ zero-shot MNIST
from PIL import Image
img = Image.open("path_to_mnist_digit.png").convert("L").resize((224, 224)).convert("RGB")
texts = [str(i) for i in range(10)]

inputs = processor(text=texts, images=img, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = clip(**inputs)

probs = outputs.logits_per_image.softmax(dim=1)[0]
print({texts[i]: float(probs[i]) for i in range(10)})
```

## ⚙️ Thông tin huấn luyện
Optimizer: Adam, learning rate = 1e-5
Batch size: 32
Số bước huấn luyện: 4000
Chỉ fine-tune vision encode