---
license: apache-2.0
pipeline_tag: zero-shot-image-classification
datasets:
- MohamedRashad/midjourney-detailed-prompts
---
|
|
|
|
|
# [i3-CLIP](https://github.com/FlameF0X/open-i3/tree/main/src/zero-shot-image-classification/CLIP): Hybrid RWKV-Attention Vision-Language Model

## Model Description

**i3-CLIP** is a 180M-parameter vision-language model that combines contrastive learning with a novel hybrid architecture. Built upon the foundation of i3-Ethan-SDPA, a 200M-parameter text generator, i3-CLIP adapts the hybrid RWKV-Attention approach for multimodal understanding, enabling efficient image-text matching and representation learning.
|
|
|
|
|
## Model Details

- **Model Type**: Vision-Language Model (CLIP-style)
- **Total Parameters**: 180,150,081 (180M)
- **Base Architecture**: Modified i3-Ethan-SDPA hybrid architecture (an SDPA variant of i3-Ethan)
- **Training Data**: Midjourney Detailed Prompts dataset
- **Model Dimensions**: 768-dimensional embeddings
- **Architecture Components**:
  - **Vision Encoder**: ResNet-based CNN (16 residual blocks)
  - **Text Encoder**: Hybrid RWKV + Transformer (12 RWKV layers + 4 attention layers)
  - **Contrastive Learning**: Learnable temperature-scaled similarity
|
|
|
|
|
### Architecture Overview

```
Vision Encoder (CNN-based):
├── Stem: Conv7x7 → BatchNorm → ReLU → MaxPool
├── Layer1: 3x ResBlocks (64 channels)
├── Layer2: 4x ResBlocks (128 channels)
├── Layer3: 6x ResBlocks (256 channels)
├── Layer4: 3x ResBlocks (512 channels)
└── AdaptiveAvgPool → Linear(512 → 768)

Text Encoder (Hybrid):
├── Token Embedding (vocab_size → 768)
├── Positional Embedding (77 positions)
├── 12x RWKV Blocks (parallel linear attention)
├── 4x Transformer Blocks (12 heads, multi-head attention)
└── Layer Norm → Take last token

Output:
└── Normalized embeddings → Contrastive loss
```
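
The dual-encoder structure above can be summarized in a few lines of PyTorch. The sketch below is illustrative only: the class and attribute names (`I3CLIP`, `vision_encoder`, `text_encoder`, `logit_scale`) are assumptions, not the repository's actual API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class I3CLIP(nn.Module):
    """Illustrative dual-encoder skeleton; names are placeholders."""
    def __init__(self, vision_encoder: nn.Module, text_encoder: nn.Module):
        super().__init__()
        self.vision_encoder = vision_encoder   # ResNet-style CNN -> 768-d embedding
        self.text_encoder = text_encoder       # 12 RWKV + 4 attention blocks -> 768-d embedding
        self.logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())

    def forward(self, images: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
        img = F.normalize(self.vision_encoder(images), dim=-1)   # (B, 768)
        txt = F.normalize(self.text_encoder(token_ids), dim=-1)  # (B, 768)
        return self.logit_scale.exp() * img @ txt.t()            # (B, B) similarity logits
```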
|
|
|
|
|
## Key Features

### 1. **Hybrid Text Processing**
- **RWKV Layers (12)**: Efficient linear-time attention for capturing sequential dependencies
- **Transformer Layers (4)**: Full self-attention for complex reasoning
- **Advantage**: Combines the O(n) efficiency of RWKV with the expressiveness of transformers (see the sketch below)
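
A minimal sketch of this hybrid stack is shown below. It assumes the layer counts described above; the RWKV blocks are whatever the repository implements (`rwkv_blocks` is passed in), and standard `nn.TransformerEncoderLayer` modules stand in for the four attention blocks.

```python
import torch
import torch.nn as nn

class HybridTextEncoder(nn.Module):
    """Illustrative hybrid text encoder: embeddings -> 12 RWKV blocks ->
    4 attention blocks -> last-token pooling. Not the repository's exact code."""
    def __init__(self, vocab_size: int, rwkv_blocks: nn.ModuleList,
                 dim: int = 768, max_len: int = 77, n_heads: int = 12):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, dim)
        self.pos_emb = nn.Parameter(torch.zeros(max_len, dim))
        self.rwkv_blocks = rwkv_blocks                    # 12 blocks in i3-CLIP
        self.attn_blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(dim, n_heads, batch_first=True)
            for _ in range(4))
        self.norm = nn.LayerNorm(dim)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:  # (B, T)
        x = self.tok_emb(token_ids) + self.pos_emb[: token_ids.size(1)]
        for block in self.rwkv_blocks:    # O(n) sequential mixing first
            x = block(x)
        for block in self.attn_blocks:    # full self-attention on top
            x = block(x)
        return self.norm(x)[:, -1]        # last token becomes the text embedding
```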
|
|
|
|
|
### 2. **JIT-Optimized RWKV**
- Custom `torch.jit.script` implementation of parallel RWKV attention
- Efficient time-mixing and channel-mixing mechanisms
- Learnable decay rates and position-dependent mixing
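
For reference, the quantity computed by RWKV time-mixing can be written as a naive sequential recurrence with learnable decay, as in the sketch below. This is a simplified reference, not the parallel JIT kernel used in the repository, and the function name is illustrative.

```python
from typing import List
import torch

@torch.jit.script
def wkv_recurrent(w: torch.Tensor, u: torch.Tensor,
                  k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Sequential reference for RWKV time-mixing (WKV).
    w, u: (C,) learnable decay and current-token bonus; k, v: (B, T, C)."""
    T = k.size(1)
    decay = torch.exp(-torch.exp(w))        # per-channel decay in (0, 1)
    num = torch.zeros_like(k[:, 0])         # running exp-weighted sum of values
    den = torch.zeros_like(k[:, 0])         # running sum of weights
    outs: List[torch.Tensor] = []
    for t in range(T):
        kt, vt = k[:, t], v[:, t]
        wt = torch.exp(u + kt)              # the current token gets the bonus u
        outs.append((num + wt * vt) / (den + wt + 1e-8))
        num = decay * num + torch.exp(kt) * vt
        den = decay * den + torch.exp(kt)
    return torch.stack(outs, dim=1)         # (B, T, C)
```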
|
|
|
|
|
### 3. **Contrastive Learning**
- Bidirectional image-text matching
- Learnable temperature parameter (initialized to log(1/0.07))
- Symmetric cross-entropy loss
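
The symmetric objective can be sketched as a standard CLIP-style loss, assumed to match the description above (the function name is illustrative):

```python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_emb: torch.Tensor,
                          text_emb: torch.Tensor,
                          logit_scale: torch.Tensor) -> torch.Tensor:
    """Symmetric cross-entropy over a batch of paired image/text embeddings."""
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = logit_scale.exp() * image_emb @ text_emb.t()    # (B, B) similarities
    targets = torch.arange(logits.size(0), device=logits.device)
    loss_i2t = F.cross_entropy(logits, targets)              # image -> matching text
    loss_t2i = F.cross_entropy(logits.t(), targets)          # text -> matching image
    return 0.5 * (loss_i2t + loss_t2i)

# Learnable temperature parameter, initialized to log(1/0.07) as in CLIP
logit_scale = torch.nn.Parameter(torch.tensor(1 / 0.07).log())
```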
|
|
|
|
|
### 4. **Efficient Vision Encoding**
- ResNet-inspired architecture with 16 residual blocks
- Progressive channel expansion (64 → 128 → 256 → 512)
- Batch normalization and residual connections for stable training
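
Each stage is presumably built from basic residual blocks like the sketch below (two 3x3 convolutions with BatchNorm, plus a 1x1 projection shortcut when the shape changes); the repository's exact block definition may differ.

```python
import torch.nn as nn

class ResBlock(nn.Module):
    """Basic residual block sketch: Conv-BN-ReLU-Conv-BN plus a skip connection."""
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 projection when the channel count or spatial resolution changes
        self.skip = (nn.Identity() if stride == 1 and in_ch == out_ch else
                     nn.Sequential(nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
                                   nn.BatchNorm2d(out_ch)))

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.skip(x))
```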
|
|
|
|
|
## Training Details

- **Framework**: PyTorch with mixed precision training
- **Optimizer**: AdamW (lr=5e-5)
- **Batch Size**: 32
- **Sequence Length**: 77 tokens (CLIP standard)
- **Image Size**: 224×224
- **Normalization**: ImageNet-style (mean=[0.48, 0.45, 0.40], std=[0.26, 0.26, 0.27]); see the preprocessing sketch after this list
- **Training Steps**: 2,000 iterations
- **Validation**: Every 100 steps
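
A preprocessing pipeline matching these settings might look like the following; the use of `torchvision` and the resize/crop strategy are assumptions, only the target size and normalization statistics come from this card.

```python
from torchvision import transforms

# 224x224 input with the normalization statistics listed above
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48, 0.45, 0.40], std=[0.26, 0.26, 0.27]),
])
```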
|
|
|
|
|
### Training Dataset

**Midjourney Detailed Prompts**: A dataset of AI-generated images with rich, detailed text descriptions, providing high-quality image-text pairs for contrastive learning.
|
|
|
|
|
## Relationship to i3-Ethan-SDPA

i3-CLIP is derived from **i3-Ethan-SDPA**, a 200M-parameter text generation model. The two models share the following characteristics:
|
|
|
|
|
### Inherited from i3-Ethan-SDPA:
1. **Hybrid RWKV-Attention Architecture**: The core design philosophy of combining RWKV blocks (efficient sequential processing) with standard attention layers (complex reasoning)
2. **JIT-Optimized RWKV Implementation**: Same parallel linear attention kernel
3. **Time-Mixing Mechanisms**: Token-wise interpolation between current and previous states
4. **Channel-Mixing FFN**: Squared ReLU activation in feed-forward networks
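
The channel-mixing FFN with squared-ReLU activation and token-wise interpolation can be sketched as a generic RWKV-style channel-mix; this is assumed to resemble the shared implementation, and the parameter names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelMix(nn.Module):
    """RWKV-style channel-mixing FFN sketch with squared-ReLU activation."""
    def __init__(self, dim: int, hidden_mult: int = 4):
        super().__init__()
        self.key = nn.Linear(dim, hidden_mult * dim, bias=False)
        self.value = nn.Linear(hidden_mult * dim, dim, bias=False)
        self.receptance = nn.Linear(dim, dim, bias=False)
        # learnable interpolation between current and previous token states
        self.mix_k = nn.Parameter(torch.full((dim,), 0.5))
        self.mix_r = nn.Parameter(torch.full((dim,), 0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:   # x: (B, T, C)
        shifted = F.pad(x, (0, 0, 1, -1))                  # previous-token states (zeros at t=0)
        xk = x * self.mix_k + shifted * (1 - self.mix_k)
        xr = x * self.mix_r + shifted * (1 - self.mix_r)
        k = torch.square(torch.relu(self.key(xk)))         # squared ReLU
        return torch.sigmoid(self.receptance(xr)) * self.value(k)
```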
|
|
|
|
|
### Key Adaptations for Vision-Language Tasks:
1. **Dual Encoder Design**: Separate vision and text pathways (vs. single autoregressive decoder)
2. **Contrastive Objective**: Image-text matching (vs. next-token prediction)
3. **Vision Processing**: Added ResNet encoder for image understanding
4. **Fixed-Length Embeddings**: Pooled representations (vs. sequential generation)
5. **Reduced Scale**: 180M parameters focused on representation (vs. 200M for generation)
|
|
|
|
|
## Use Cases

- **Zero-shot image classification** (see the example after this list)
- **Image-text retrieval**
- **Visual search engines**
- **Multimodal embedding generation**
- **Cross-modal understanding tasks**
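
A hypothetical zero-shot classification flow is shown below; `model`, `preprocess`, and `tokenize` are placeholders for whatever the repository actually exposes, and `encode_image`/`encode_text` are assumed method names.

```python
import torch
import torch.nn.functional as F

def zero_shot_classify(model, preprocess, tokenize, image, class_names):
    """Score an image against text prompts built from class names."""
    prompts = [f"a photo of a {name}" for name in class_names]
    with torch.no_grad():
        img = F.normalize(model.encode_image(preprocess(image).unsqueeze(0)), dim=-1)
        txt = F.normalize(model.encode_text(tokenize(prompts)), dim=-1)
        probs = (model.logit_scale.exp() * img @ txt.t()).softmax(dim=-1)
    return dict(zip(class_names, probs.squeeze(0).tolist()))
```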
|
|
|
|
|
## Limitations

1. **Training Scale**: Trained for only 2,000 iterations on a single dataset
2. **Domain Specificity**: Optimized for Midjourney-style synthetic images
3. **Limited Context**: 77-token text limit (CLIP standard)
4. **No Fine-Grained Localization**: Global image embeddings only
5. **Checkpoint Availability**: Only the best checkpoint, selected by validation loss, is saved
|
|
|
|
|
## Ethical Considerations

- **Synthetic Training Data**: The model is trained on AI-generated images and may not generalize well to real-world photographs
- **Bias Propagation**: May inherit biases from the Midjourney dataset
- **Content Safety**: Should be evaluated for fairness across diverse demographics before deployment