Update README.md
README.md
CHANGED
|
@@ -1,3 +1,45 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: mit
|
| 3 |
-
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
datasets:
|
| 4 |
+
- cifar10
|
| 5 |
+
metrics:
|
| 6 |
+
- accuracy
|
| 7 |
+
library_name: pytorch
|
| 8 |
+
tags:
|
| 9 |
+
- image-captioning
|
| 10 |
+
- resnet18
|
| 11 |
+
- lstm
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# ResNet18 Image Captioning Weights (CIFAR-10)
|
| 15 |
+
|
| 16 |
+
This repository contains the trained weights for an image captioning system consisting of a **CNN Encoder** and an **RNN Decoder**, fine-tuned on the CIFAR-10 dataset.
|
| 17 |
+
|
| 18 |
+
## 📦 Model Components
|
| 19 |
+
|
| 20 |
+
### 1. Encoder (`encoder`)
|
| 21 |
+
- **Architecture:** ResNet18 (Feature Extractor)
|
| 22 |
+
- **Output Dim:** 256
|
| 23 |
+
- **Purpose:** Extracts high-level visual features from input images. The final fully connected layer was replaced to map features to the embedding space.
|
| 24 |
+
|
| 25 |
+
### 2. Decoder (`decoder`)
|
| 26 |
+
- **Architecture:** LSTM-based RNN
|
| 27 |
+
- **Hidden Dim:** 512
|
| 28 |
+
- **Embedding Dim:** 256
|
| 29 |
+
- **Purpose:** Generates descriptive sequences based on the features received from the Encoder.
|
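A decoder with the dimensions listed above might be sketched as follows. The class name `DecoderRNN`, the `vocab_size` value, and the choice to feed the image features as the first step of the LSTM input sequence are all assumptions for illustration; the README does not say how the features are passed to the LSTM, and the real `decoder` checkpoint may use different layer names.

```python
import torch
import torch.nn as nn


class DecoderRNN(nn.Module):
    """Hypothetical LSTM decoder: embedding dim 256, hidden dim 512."""

    def __init__(self, vocab_size: int, embed_dim: int = 256, hidden_dim: int = 512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features: torch.Tensor, captions: torch.Tensor) -> torch.Tensor:
        # captions: (batch, seq_len) token ids -> (batch, seq_len, embed_dim)
        tokens = self.embed(captions)
        # Assumption: the image feature vector is prepended as the first input step.
        inputs = torch.cat([features.unsqueeze(1), tokens], dim=1)
        hidden, _ = self.lstm(inputs)
        # Per-step scores over the vocabulary: (batch, seq_len + 1, vocab_size)
        return self.fc(hidden)


# vocab_size=100 is a placeholder; the real vocabulary size is not given in the README.
decoder = DecoderRNN(vocab_size=100).eval()
dummy_features = torch.randn(2, 256)           # stands in for encoder output
dummy_captions = torch.randint(0, 100, (2, 5))  # two captions of five token ids
with torch.no_grad():
    logits = decoder(dummy_features, dummy_captions)
```

The embedding dimension equals the encoder's output dimension (256), which is what allows the feature vector and the token embeddings to share one input sequence in this sketch.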
| 30 |
+
|
| 31 |
+
## 🚀 Usage
|
| 32 |
+
|
| 33 |
+
You can download these weights with the `huggingface_hub` library and load them with PyTorch:
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
from huggingface_hub import hf_hub_download
|
| 37 |
+
import torch
|
| 38 |
+
|
| 39 |
+
# Download weights
|
| 40 |
+
encoder_path = hf_hub_download(repo_id="Sher1988/image-classifier-weights", filename="encoder")
|
| 41 |
+
decoder_path = hf_hub_download(repo_id="Sher1988/image-classifier-weights", filename="decoder")
|
| 42 |
+
|
| 43 |
+
# Load into your own model instances (define and instantiate `encoder` and `decoder` first)
|
| 44 |
+
# encoder.load_state_dict(torch.load(encoder_path, map_location='cpu'))
|
| 45 |
+
# decoder.load_state_dict(torch.load(decoder_path, map_location='cpu'))
|