erjui committed
Commit 598f46d · verified · 1 Parent(s): a01185a

Upload README.md with huggingface_hub

Files changed (1)
1. README.md +69 -17
README.md CHANGED
@@ -13,11 +13,9 @@ library_name: pytorch
 pipeline_tag: image-classification
 ---
 
-# DHO: Simple Few-shot Semi-supervised Knowledge Distillation
+# DHO: Simple yet Effective Semi-supervised Knowledge Distillation from Vision-Language Models via Dual-Head Optimization
 
 [![arXiv](https://img.shields.io/badge/arXiv-2505.07675v1-b31b1b.svg)](https://arxiv.org/abs/2505.07675v1)
-[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/simple-semi-supervised-knowledge-distillation/semi-supervised-image-classification-on-1)](https://paperswithcode.com/sota/semi-supervised-image-classification-on-1?p=simple-semi-supervised-knowledge-distillation)
-[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/simple-semi-supervised-knowledge-distillation/semi-supervised-image-classification-on-2)](https://paperswithcode.com/sota/semi-supervised-image-classification-on-2?p=simple-semi-supervised-knowledge-distillation)
 
 This repository contains pretrained checkpoints for **DHO (Dual-Head Optimization)**, a simple yet effective approach for semi-supervised knowledge distillation from Vision-Language Models.
 
@@ -52,33 +50,87 @@ The method achieves state-of-the-art performance on ImageNet semi-supervised lea
 
 ```python
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 import clip
-
-# Load the student model architecture
+from huggingface_hub import hf_hub_download
+
+# Define the DHO StudentModel architecture with dual heads
+class StudentModel(nn.Module):
+    def __init__(self, num_classes=1000, model_name='ViT-B/16'):
+        super().__init__()
+        # Load the CLIP visual backbone
+        clip_model, _ = clip.load(model_name, device='cpu')
+        self.backbone = clip_model.float().visual
+
+        # Feature dimensions per architecture
+        in_features = {
+            'RN50': 1024,
+            'ViT-B/16': 512,
+            'ViT-L/14': 768,
+            'ViT-L/14@336px': 768
+        }[model_name]
+
+        # Dual-head architecture
+        self.ce_head = nn.Linear(in_features, num_classes)  # CE branch
+        self.kd_head = nn.Linear(in_features, num_classes)  # KD branch
+
+    def forward(self, x):
+        features = self.backbone(x)
+        ce_out = self.ce_head(features)
+        kd_out = self.kd_head(F.normalize(features, dim=1)) * 100
+        return ce_out, kd_out
+
+# Download and load a checkpoint
 device = "cuda" if torch.cuda.is_available() else "cpu"
+checkpoint_path = hf_hub_download(repo_id="erjui/dho", filename="vit_b_10.pt")
+checkpoint = torch.load(checkpoint_path, map_location=device)
 
-# For ViT-B/16 checkpoints
-model, preprocess = clip.load("ViT-B-16", device=device)
+# Initialize the model
+model = StudentModel(num_classes=1000, model_name='ViT-B/16').to(device)
 
-# Load DHO checkpoint
-checkpoint = torch.hub.load_state_dict_from_url(
-    "https://huggingface.co/erjui/dho/resolve/main/vit_b_10.pt",
-    map_location=device
-)
+# Handle a DDP-wrapped state_dict
+state_dict = checkpoint['model_state_dict']
+state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
+model.load_state_dict(state_dict)
 
-# Load the state dict
-model.load_state_dict(checkpoint['model_state_dict'])
+# Get the optimal inference parameters
+alpha = checkpoint['alpha']  # Weight for the CE head
+beta = checkpoint['beta']    # Temperature for the KD head
 model.eval()
 
-# Use the model for inference
+# Inference example
 from PIL import Image
+import torchvision.transforms as transforms
+
+# CLIP preprocessing
+preprocess = transforms.Compose([
+    transforms.Resize(224),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+                         std=(0.26862954, 0.26130258, 0.27577711))
+])
 
 image = preprocess(Image.open("path/to/image.jpg")).unsqueeze(0).to(device)
 with torch.no_grad():
-    image_features = model.encode_image(image)
-    # ... your inference code
+    ce_logits, kd_logits = model(image)
+
+# Combine the two heads' predictions using the saved parameters
+probs_ce = F.softmax(ce_logits, dim=1)
+probs_kd = F.softmax(kd_logits / beta, dim=1)
+probs = alpha * probs_ce + (1 - alpha) * probs_kd
+
+predicted_class = probs.argmax(dim=1)
+print(f"Predicted class: {predicted_class.item()}")
 ```
 
+**Important Notes:**
+- DHO checkpoints contain: `model_state_dict`, `epoch`, `acc`, `alpha`, `beta`
+- The model has a **dual-head architecture** (a CE head plus a KD head)
+- Use the saved `alpha` and `beta` parameters for optimal inference
+- For ViT-L checkpoints, change `model_name='ViT-L/14'` and use image size 224 (or 336 for ViT-L/14@336px)
+
 ### Training Your Own Model
 
 To train your own DHO model, please visit the [official GitHub repository](https://github.com/yourusername/DHO) for detailed instructions and training scripts.
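The dual-head fusion rule added in this README (mix the CE head's softmax with a temperature-scaled KD softmax using the checkpoint's `alpha` and `beta`) can be sanity-checked in isolation. The sketch below is a dependency-free illustration with hypothetical logits and made-up `alpha`/`beta` values, not a reproduction of the repository's code:

```python
import math

def softmax(xs):
    # Numerically stable softmax over a plain Python list
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def dho_combine(ce_logits, kd_logits, alpha, beta):
    # alpha weights the CE head; beta is the KD-head temperature,
    # mirroring: probs = alpha * softmax(ce) + (1 - alpha) * softmax(kd / beta)
    p_ce = softmax(ce_logits)
    p_kd = softmax([z / beta for z in kd_logits])
    return [alpha * a + (1 - alpha) * b for a, b in zip(p_ce, p_kd)]

# Hypothetical 3-class logits: the CE head prefers class 0, the KD head class 1
probs = dho_combine([2.0, 0.5, -1.0], [1.0, 3.0, 0.0], alpha=0.7, beta=2.0)
print(probs.index(max(probs)))  # → 0 (the CE head dominates at alpha = 0.7)
```

Because each head's output is converted to a proper distribution before mixing, the combined `probs` is itself a valid distribution for any `alpha` in [0, 1].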