---
license: mit
datasets:
- cifar10
metrics:
- accuracy
library_name: pytorch
tags:
- image-captioning
- resnet18
- lstm
---

# ResNet18 Image Captioning Weights (CIFAR-10)

This repository contains the trained weights for an image captioning system consisting of a **CNN Encoder** and an **RNN Decoder**, fine-tuned on the CIFAR-10 dataset.

## 📦 Model Components

### 1. Encoder (`encoder`)
- **Architecture:** ResNet18 (Feature Extractor)
- **Output Dim:** 256
- **Purpose:** Extracts high-level visual features from input images. The final fully connected layer was replaced so that it maps ResNet18's 512-dimensional pooled features to the 256-dimensional embedding space.

### 2. Decoder (`decoder`)
- **Architecture:** LSTM-based RNN
- **Hidden Dim:** 512
- **Embedding Dim:** 256
- **Purpose:** Generates descriptive token sequences (captions) conditioned on the features produced by the Encoder.
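
Likewise, here is a minimal sketch of an LSTM decoder with the listed dimensions (embedding size 256, hidden size 512). It assumes the common design in which the encoder's 256-d feature vector is fed as the first LSTM input step, a design consistent with the encoder output size matching the embedding size. The vocabulary size, the class and attribute names (`DecoderRNN`, `embed`, `lstm`, `linear`), and the `greedy_decode` helper are all hypothetical and may differ from the code that produced these weights.

```python
import torch
import torch.nn as nn


class DecoderRNN(nn.Module):
    """Hypothetical LSTM decoder: embed_size=256, hidden_size=512."""

    def __init__(self, vocab_size: int, embed_size: int = 256,
                 hidden_size: int = 512, num_layers: int = 1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features: torch.Tensor, captions: torch.Tensor) -> torch.Tensor:
        # Teacher forcing: the image feature is the first input step,
        # followed by the embeddings of all caption tokens except the last.
        embeddings = self.embed(captions[:, :-1])                    # (B, T-1, E)
        inputs = torch.cat([features.unsqueeze(1), embeddings], 1)   # (B, T, E)
        hiddens, _ = self.lstm(inputs)                               # (B, T, H)
        return self.linear(hiddens)                                  # (B, T, vocab)

    @torch.no_grad()
    def greedy_decode(self, features: torch.Tensor, max_len: int = 20) -> torch.Tensor:
        # Feed the image feature first, then feed back the argmax token each step.
        inputs = features.unsqueeze(1)  # (B, 1, E)
        states = None                   # LSTM initialises (h, c) to zeros
        tokens = []
        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)
            logits = self.linear(hiddens.squeeze(1))  # (B, vocab)
            predicted = logits.argmax(dim=1)          # (B,)
            tokens.append(predicted)
            inputs = self.embed(predicted).unsqueeze(1)
        return torch.stack(tokens, dim=1)  # (B, max_len) token ids


decoder = DecoderRNN(vocab_size=100).eval()  # vocab_size=100 is a placeholder
features = torch.randn(2, 256)               # stands in for encoder output
captions = torch.randint(0, 100, (2, 12))
logits = decoder(features, captions)
print(logits.shape)   # torch.Size([2, 12, 100])
sampled = decoder.greedy_decode(features, max_len=5)
print(sampled.shape)  # torch.Size([2, 5])
```

Mapping the sampled ids back to words requires the vocabulary used in training, which is not part of this repository.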

## 🚀 Usage

Download the weight files with the `huggingface_hub` library and load them with PyTorch:

```python
from huggingface_hub import hf_hub_download
import torch

# Download the weight files (cached locally by huggingface_hub)
encoder_path = hf_hub_download(repo_id="Sher1988/image-classifier-weights", filename="encoder")
decoder_path = hf_hub_download(repo_id="Sher1988/image-classifier-weights", filename="decoder")

# Load the state dicts into instances of your own model classes;
# their architectures must match the ones used during training.
# weights_only=True restricts unpickling to tensors and plain containers.
# encoder.load_state_dict(torch.load(encoder_path, map_location='cpu', weights_only=True))
# decoder.load_state_dict(torch.load(decoder_path, map_location='cpu', weights_only=True))
