---
license: apache-2.0
pipeline_tag: zero-shot-image-classification
datasets:
- MohamedRashad/midjourney-detailed-prompts
---
# [i3-CLIP](https://github.com/FlameF0X/open-i3/tree/main/src/zero-shot-image-classification/CLIP): Hybrid RWKV-Attention Vision-Language Model
## Model Description
**i3-CLIP** is a 180M-parameter vision-language model that combines contrastive learning with a novel hybrid architecture. It builds on i3-Ethan-SDPA (the SDPA variant of i3-Ethan, a 200M-parameter text generator) and adapts its hybrid RWKV-Attention design to multimodal understanding, enabling efficient image-text matching and representation learning.
## Model Details
- **Model Type**: Vision-Language Model (CLIP-style)
- **Total Parameters**: 180,150,081 (180M)
- **Base Architecture**: Modified i3-Ethan-SDPA (the SDPA variant of i3-Ethan) hybrid architecture
- **Training Data**: Midjourney Detailed Prompts dataset
- **Model Dimensions**: 768-dimensional embeddings
- **Architecture Components**:
  - **Vision Encoder**: ResNet-based CNN (16 residual blocks)
  - **Text Encoder**: Hybrid RWKV + Transformer (12 RWKV layers + 4 attention layers)
  - **Contrastive Learning**: Learnable temperature-scaled similarity
### Architecture Overview
```
Vision Encoder (CNN-based):
├── Stem: Conv7x7 → BatchNorm → ReLU → MaxPool
├── Layer1: 3x ResBlocks (64 channels)
├── Layer2: 4x ResBlocks (128 channels)
├── Layer3: 6x ResBlocks (256 channels)
├── Layer4: 3x ResBlocks (512 channels)
└── AdaptiveAvgPool → Linear(512 → 768)

Text Encoder (Hybrid):
├── Token Embedding (vocab_size → 768)
├── Positional Embedding (77 positions)
├── 12x RWKV Blocks (parallel linear attention)
├── 4x Transformer Blocks (12 heads, multi-head attention)
└── Layer Norm → Take last token

Output:
└── Normalized embeddings → Contrastive loss
```
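The text path ends with a layer norm followed by "take last token", and both embeddings are then normalized. A minimal standard-library sketch of that pooling step is below. It assumes that "last token" means the last non-padding position marked by an attention mask (if the model literally takes the final position, pass a mask of all ones). `pool_last_token` and `l2_normalize` are hypothetical names, not functions from the i3-CLIP code.

```python
import math
from typing import List


def pool_last_token(hidden: List[List[float]], mask: List[int]) -> List[float]:
    """Return the hidden state at the last position whose mask value is 1.

    Assumption: padded positions are marked 0 in `mask`.
    """
    last = max(i for i, m in enumerate(mask) if m == 1)
    return hidden[last]


def l2_normalize(vec: List[float], eps: float = 1e-12) -> List[float]:
    """Scale a vector to unit Euclidean length (eps guards against a zero vector)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


# Toy sequence of three 2-d hidden states; the third position is padding.
hidden = [[1.0, 0.0], [3.0, 4.0], [9.0, 9.0]]
text_embedding = l2_normalize(pool_last_token(hidden, [1, 1, 0]))
```

Normalizing to unit length makes the later dot product between image and text embeddings equal to their cosine similarity.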
## Key Features
### 1. **Hybrid Text Processing**
- **RWKV Layers (12)**: Efficient linear-time attention for capturing sequential dependencies
- **Transformer Layers (4)**: Full self-attention for complex reasoning
- **Advantage**: Combines O(n) efficiency of RWKV with the expressiveness of transformers
### 2. **JIT-Optimized RWKV**
- Custom `torch.jit.script` implementation of parallel RWKV attention
- Efficient time-mixing and channel-mixing mechanisms
- Learnable decay rates and position-dependent mixing
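The card does not spell out the RWKV time-mixing formula. Assuming it follows the RWKV-4 "WKV" operator (a learnable per-channel decay `w > 0` and a bonus `u` for the current token), the sketch below computes it for a single channel with the linear-time recurrence and checks it against the explicit quadratic sum. It is plain Python for readability, not the repository's `torch.jit.script` kernel, and it omits the running-maximum trick that real kernels use to avoid overflow in `exp`.

```python
import math
from typing import List


def wkv_recurrent(k: List[float], v: List[float], w: float, u: float) -> List[float]:
    """Single-channel RWKV-4-style WKV computed in O(n) with a recurrence."""
    a, b = 0.0, 0.0  # decayed weighted sum of past values / of past weights
    out = []
    for kt, vt in zip(k, v):
        cur = math.exp(u + kt)  # current token gets the extra bonus u
        out.append((a + cur * vt) / (b + cur))
        # Decay the history by exp(-w), then fold in token t.
        a = math.exp(-w) * a + math.exp(kt) * vt
        b = math.exp(-w) * b + math.exp(kt)
    return out


def wkv_direct(k: List[float], v: List[float], w: float, u: float) -> List[float]:
    """Same quantity via the explicit O(n^2) sum, used only as a reference."""
    out = []
    for t in range(len(k)):
        num = math.exp(u + k[t]) * v[t]
        den = math.exp(u + k[t])
        for i in range(t):
            weight = math.exp(-(t - 1 - i) * w + k[i])
            num += weight * v[i]
            den += weight
        out.append(num / den)
    return out
```

The two functions agree, which is the point of the linear-time form: each step only updates two running sums instead of revisiting every earlier token.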
### 3. **Contrastive Learning**
- Bidirectional image-text matching
- Learnable temperature parameter (initialized to log(1/0.07))
- Symmetric cross-entropy loss
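A minimal, dependency-free sketch of the symmetric contrastive loss, assuming the standard CLIP formulation: embeddings are L2-normalized, cosine similarities are multiplied by `exp(log_scale)` (initialized to `log(1/0.07)`, so the scale starts near 14.3), and cross-entropy is averaged over the image-to-text and text-to-image directions. The function names are illustrative and not the repository's API.

```python
import math
from typing import List

Matrix = List[List[float]]


def _normalize(v: List[float]) -> List[float]:
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def _cross_entropy(logits: List[float], target: int) -> float:
    """-log softmax(logits)[target], computed with the log-sum-exp trick."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]


def clip_loss(img: Matrix, txt: Matrix, log_scale: float = math.log(1 / 0.07)) -> float:
    """Symmetric contrastive loss: pair i in the batch is the positive for row/column i."""
    img = [_normalize(v) for v in img]
    txt = [_normalize(v) for v in txt]
    scale = math.exp(log_scale)  # learnable temperature in the real model
    logits = [[scale * sum(a * b for a, b in zip(i, t)) for t in txt] for i in img]
    n = len(logits)
    loss_i2t = sum(_cross_entropy(logits[r], r) for r in range(n)) / n
    cols = [[logits[r][c] for r in range(n)] for c in range(n)]  # transpose
    loss_t2i = sum(_cross_entropy(cols[c], c) for c in range(n)) / n
    return (loss_i2t + loss_t2i) / 2
```

Matched pairs drive the loss toward zero, while a batch whose captions are swapped between images yields a large loss; the two directions make the objective symmetric in images and texts.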
### 4. **Efficient Vision Encoding**
- ResNet-inspired architecture with 16 residual blocks
- Progressive channel expansion (64 → 128 → 256 → 512)
- Batch normalization and residual connections for stable training
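The spatial size after each vision stage can be traced with the usual convolution output formula. This sketch assumes standard ResNet-34 strides and padding (a stride-2 7×7 stem conv, a stride-2 3×3 max-pool, and a stride-2 first block in layers 2 to 4); the card gives the block counts and channel widths but not the strides, so those are assumptions.

```python
from typing import List, Tuple

# (stage name, residual blocks, output channels, stride of the stage's first block)
# Block counts and channels come from the card; the strides are assumed.
STAGES = [
    ("layer1", 3, 64, 1),
    ("layer2", 4, 128, 2),
    ("layer3", 6, 256, 2),
    ("layer4", 3, 512, 2),
]


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_vision_shapes(image_size: int = 224) -> List[Tuple[str, int, int]]:
    """Return (stage, channels, spatial size) after the stem and each stage."""
    s = conv_out(image_size, kernel=7, stride=2, padding=3)  # stem conv 7x7
    s = conv_out(s, kernel=3, stride=2, padding=1)           # stem max-pool
    shapes = [("stem", 64, s)]
    for name, _blocks, channels, stride in STAGES:
        s = conv_out(s, kernel=3, stride=stride, padding=1)  # first 3x3 conv of stage
        shapes.append((name, channels, s))
    return shapes
```

Under these assumptions a 224×224 image becomes a 7×7×512 feature map, which `AdaptiveAvgPool` reduces to a 512-d vector before the 512 → 768 projection, and the four stages together contain 3 + 4 + 6 + 3 = 16 residual blocks.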
## Training Details
- **Framework**: PyTorch with mixed precision training
- **Optimizer**: AdamW (lr=5e-5)
- **Batch Size**: 32
- **Sequence Length**: 77 tokens (CLIP standard)
- **Image Size**: 224×224
- **Normalization**: CLIP-style per-channel mean/std (`mean=[0.48, 0.45, 0.40]`, `std=[0.26, 0.26, 0.27]`)
- **Training Steps**: 2,000 iterations
- **Validation**: Every 100 steps
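The preprocessing implied by these settings can be sketched as follows. `MEAN`, `STD`, and the 77-token context length come from the card; `pad_id = 0` and the simple cut-off truncation are hypothetical choices, and a real tokenizer may handle start and end-of-text tokens differently.

```python
from typing import List, Sequence

# Rounded values from the card; the exact training constants may differ.
MEAN = (0.48, 0.45, 0.40)
STD = (0.26, 0.26, 0.27)
CONTEXT_LENGTH = 77


def normalize_pixel(rgb: Sequence[int]) -> List[float]:
    """Map an 8-bit RGB pixel to the per-channel normalized values the encoder expects."""
    return [(c / 255.0 - m) / s for c, m, s in zip(rgb, MEAN, STD)]


def pad_or_truncate(ids: Sequence[int], pad_id: int = 0) -> List[int]:
    """Fit a token-id sequence to exactly CONTEXT_LENGTH positions (pad_id is hypothetical)."""
    ids = list(ids[:CONTEXT_LENGTH])
    return ids + [pad_id] * (CONTEXT_LENGTH - len(ids))
```

Applied to every pixel of a 224×224 image, `normalize_pixel` produces the float tensor fed to the vision encoder; `pad_or_truncate` gives every caption the fixed length expected by the 77-position embedding.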
### Training Dataset
**Midjourney Detailed Prompts**: A dataset of AI-generated images with rich, detailed text descriptions, providing high-quality image-text pairs for contrastive learning.
## Relationship to i3-Ethan-SDPA
i3-CLIP is derived from **i3-Ethan-SDPA**, a 200M parameter text generation model with the following shared characteristics:
### Inherited from i3-Ethan-SDPA:
1. **Hybrid RWKV-Attention Architecture**: The core design philosophy of combining RWKV blocks (efficient sequential processing) with standard attention layers (complex reasoning)
2. **JIT-Optimized RWKV Implementation**: Same parallel linear attention kernel
3. **Time-Mixing Mechanisms**: Token-wise interpolation between current and previous states
4. **Channel-Mixing FFN**: Squared ReLU activation in feed-forward networks
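The token-shift interpolation and squared-ReLU channel mixing listed above can be sketched for a single token as follows. This assumes the RWKV-4 channel-mixing layout (a sigmoid "receptance" gate multiplied into a squared-ReLU feed-forward), which the card does not state explicitly; the weight names `W_k`, `W_v`, `W_r`, `mu_k`, and `mu_r` are hypothetical.

```python
import math
from typing import List

Vec = List[float]
Mat = List[List[float]]


def token_shift(x: Vec, x_prev: Vec, mu: Vec) -> Vec:
    """Per-channel interpolation between the current and the previous token."""
    return [m * a + (1 - m) * b for a, b, m in zip(x, x_prev, mu)]


def matvec(W: Mat, x: Vec) -> Vec:
    """Dense matrix-vector product."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def squared_relu(x: Vec) -> Vec:
    return [max(0.0, v) ** 2 for v in x]


def channel_mix(x: Vec, x_prev: Vec, mu_k: Vec, mu_r: Vec,
                W_k: Mat, W_v: Mat, W_r: Mat) -> Vec:
    """RWKV-4-style channel mixing: a sigmoid gate times a squared-ReLU FFN."""
    k = squared_relu(matvec(W_k, token_shift(x, x_prev, mu_k)))
    r = [1 / (1 + math.exp(-z)) for z in matvec(W_r, token_shift(x, x_prev, mu_r))]
    return [ri * vi for ri, vi in zip(r, matvec(W_v, k))]
```

Because `token_shift` only looks one step back, each layer mixes in local context cheaply; the squared ReLU zeroes negative activations and amplifies large positive ones.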
### Key Adaptations for Vision-Language Tasks:
1. **Dual Encoder Design**: Separate vision and text pathways (vs. single autoregressive decoder)
2. **Contrastive Objective**: Image-text matching (vs. next-token prediction)
3. **Vision Processing**: Added ResNet encoder for image understanding
4. **Fixed-Length Embeddings**: Pooled representations (vs. sequential generation)
5. **Reduced Scale**: 180M parameters focused on representation (vs. 200M for generation)
## Use Cases
- **Zero-shot image classification**
- **Image-text retrieval**
- **Visual search engines**
- **Multimodal embedding generation**
- **Cross-modal understanding tasks**
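Zero-shot classification with a CLIP-style model reduces to comparing one image embedding against text embeddings of candidate label prompts. The sketch below assumes both embeddings have already been produced by i3-CLIP's encoders (loading and calling them is not covered by this card) and uses a hypothetical fixed `scale` of 100 in place of the model's learned temperature.

```python
import math
from typing import Dict, List


def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def zero_shot_classify(image_emb: List[float], label_embs: Dict[str, List[float]],
                       scale: float = 100.0) -> Dict[str, float]:
    """Softmax over scaled cosine similarities between one image and each label prompt."""
    logits = {label: scale * cosine(image_emb, emb) for label, emb in label_embs.items()}
    m = max(logits.values())  # subtract the max for numerical stability
    exps = {label: math.exp(z - m) for label, z in logits.items()}
    total = sum(exps.values())
    return {label: e / total for label, e in exps.items()}
```

In practice each label embedding would come from encoding a prompt such as "a photo of a cat"; that template is a common convention, not something this card specifies. Image-text retrieval uses the same similarity scores, sorted instead of passed through a softmax.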
## Limitations
1. **Training Scale**: Trained for only 2,000 iterations on a single dataset
2. **Domain Specificity**: Optimized for Midjourney-style synthetic images
3. **Limited Context**: 77-token text limit (CLIP standard)
4. **No Fine-Grained Localization**: Global image embeddings only
5. **Checkpoint Selection**: The saved model is the checkpoint with the lowest validation loss (validation runs every 100 steps)
## Ethical Considerations
- **Synthetic Training Data**: Because the model was trained on AI-generated images, it may not generalize well to real-world photos
- **Bias Propagation**: May inherit biases from the Midjourney dataset
- **Content Safety**: The model should be evaluated for fairness across diverse demographics before deployment