---
library_name: pytorch
tags:
- contrastive-learning
- tag-classification
- semantic-search
- embeddings
- persona-conditioned
- pretrained-backbone
---

# modernbert-base-tag-classification

This is a **pretrained backbone model** (`answerdotai/ModernBERT-base`) used for tag classification within a contrastive-learning framework.

## Model Description

This model uses the `answerdotai/ModernBERT-base` backbone directly, without fine-tuning. It is intended for zero-shot tag classification: the query text and each candidate tag are embedded with the same pretrained model, and the tags are ranked by semantic similarity to the query.
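
Conceptually, zero-shot ranking reduces to comparing one query vector against many tag vectors. The sketch below ranks toy 3-dimensional vectors by cosine similarity. Cosine similarity is an assumption chosen for illustration (the repository's `compute_ranked_tags` may score tags differently), and `rank_tags_by_cosine` is a hypothetical helper, not part of the module.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def rank_tags_by_cosine(
    query_emb: Sequence[float],
    tag_embs: Sequence[Sequence[float]],
    tags: Sequence[str],
) -> List[Tuple[int, str, float]]:
    """Hypothetical helper: return (rank, tag, score), best match first.

    Cosine similarity is an assumed scoring rule for this sketch.
    """
    if len(tag_embs) != len(tags):
        raise ValueError("tag_embs and tags must have the same length")
    scored = [(tag, cosine_similarity(query_emb, emb)) for tag, emb in zip(tags, tag_embs)]
    scored.sort(key=lambda item: item[1], reverse=True)
    return [(rank, tag, score) for rank, (tag, score) in enumerate(scored, start=1)]


# Toy 3-d vectors standing in for 768-d ModernBERT-base embeddings.
query = [0.9, 0.1, 0.0]
tags = ["computer-vision", "tensorflow", "cooking"]
embs = [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.0, 1.0]]

for rank, tag, score in rank_tags_by_cosine(query, embs, tags):
    print(f"{rank:2d}. {tag:20s} (score: {score:.4f})")
```

Because every tag is scored independently against the query, no labeled positives or negatives are needed, which is what makes the setting zero-shot.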

## Usage

See the repository's README.md for detailed usage examples built on the module abstractions.

## Model Architecture

- **Backbone**: `answerdotai/ModernBERT-base`
- **Type**: Pretrained backbone (no fine-tuning)
- **Embedding Dimension**: 768 (the hidden size of ModernBERT-base)
- **Pooling**: CLS token (`pooling_mode="cls"` in the usage example)
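
The usage example configures `pooling_mode="cls"`, which takes the hidden state of the first token as the text embedding. The pure-Python sketch below contrasts this with masked mean pooling on a toy batch. The helper names `cls_pool` and `mean_pool` are hypothetical and not part of the `SharedTextBackbone` API; real hidden states would be 768-dimensional tensors rather than nested lists.

```python
from typing import List

# hidden[b][t] is the hidden-state vector of token t in sequence b;
# mask[b][t] is 1 for real tokens and 0 for padding.
Hidden = List[List[List[float]]]
Mask = List[List[int]]


def cls_pool(hidden: Hidden) -> List[List[float]]:
    """Hypothetical CLS pooling: keep the first token's vector per sequence."""
    return [seq[0] for seq in hidden]


def mean_pool(hidden: Hidden, mask: Mask) -> List[List[float]]:
    """Hypothetical masked mean pooling: average only non-padding tokens."""
    pooled = []
    for seq, seq_mask in zip(hidden, mask):
        kept = [vec for vec, m in zip(seq, seq_mask) if m]
        dim = len(seq[0])
        if not kept:
            # Fully padded sequence: fall back to a zero vector.
            pooled.append([0.0] * dim)
            continue
        pooled.append([sum(vec[d] for vec in kept) / len(kept) for d in range(dim)])
    return pooled


# One sequence of three tokens (the last is padding), 2-d hidden states.
hidden = [[[1.0, 0.0], [3.0, 2.0], [9.0, 9.0]]]
mask = [[1, 1, 0]]
print(cls_pool(hidden))         # [[1.0, 0.0]]
print(mean_pool(hidden, mask))  # [[2.0, 1.0]]
```

CLS pooling ignores every token after the first, so its output does not depend on the padding mask; mean pooling needs the mask so that padding vectors do not dilute the average.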

## Usage Example

```python
"""
Example: Using ModernBERT-base for Tag Classification

This example shows how to use the pretrained ModernBERT-base backbone
for zero-shot tag classification using our module abstractions.

Installation:
    pip install git+https://github.com/Pieces/TAG-module.git@main
    # Or, from a local checkout: pip install -e .

Note: ModernBERT requires Python < 3.14 due to torch.compile compatibility.
"""

import torch

from tags_model.models.backbone import SharedTextBackbone
from playground.validate_from_checkpoint import compute_ranked_tags

# Load the pretrained backbone (frozen: no fine-tuning is applied)
print("Loading ModernBERT-base...")
backbone = SharedTextBackbone(
    model_name="answerdotai/ModernBERT-base",
    embedding_dim=768,  # ModernBERT-base hidden size
    freeze_backbone=True,
    pooling_mode="cls",
    # May be unnecessary on transformers versions with native ModernBERT support
    trust_remote_code=True,
)
backbone.eval()
print("✓ Model loaded!")

# Example query
query_text = "Machine learning model for image classification using PyTorch"

# Candidate tags to rank
candidate_tags = [
    "pytorch", "machine-learning", "deep-learning", "computer-vision",
    "neural-networks", "cnn", "image-classification", "tensorflow",
    "data-science", "python",
]

print(f"\nQuery: {query_text}")
print(f"Candidate tags: {candidate_tags}\n")

# Encode the query and the tags with the same frozen backbone
with torch.inference_mode():
    query_emb = backbone.encode_texts([query_text], max_length=512, return_dict=False)[0]
    tag_embs = backbone.encode_texts(candidate_tags, max_length=512, return_dict=False)

print(f"Query embedding shape: {query_emb.shape}")
print(f"Tag embeddings shape: {tag_embs.shape}")

# Rank tags by similarity. In the zero-shot setting there are no labeled
# positive or negative tags, so those inputs are empty and every candidate
# is passed in as a "general" tag.
emb_dim = tag_embs.shape[-1]
ranked_tags = compute_ranked_tags(
    query_emb=query_emb,
    pos_embs=torch.empty(0, emb_dim),  # No positives for zero-shot
    neg_embs=torch.empty(0, emb_dim),  # No negatives for zero-shot
    general_embs=tag_embs,
    positive_tags=[],
    negative_tags=[],
    general_tags=candidate_tags,
)

# Display the five top-ranked tags
print("\n" + "=" * 60)
print("Top Ranked Tags:")
print("=" * 60)
for tag, rank, _label, score in ranked_tags[:5]:
    print(f"{rank:2d}. {tag:20s} (score: {score:.4f})")

print("\n" + "=" * 60)
print("Example complete!")
```

### Running the Example

Save the script above as `modernbert_example.py`, then run:

```bash
# Install the repository first
pip install git+https://github.com/Pieces/TAG-module.git@main
# Or, for local development (from the repository root):
pip install -e .

# Run the example
python modernbert_example.py
```
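
The example's docstring notes that ModernBERT requires Python < 3.14 because of `torch.compile` compatibility. A minimal fail-fast check, assuming that constraint holds for your environment, could run before the backbone is loaded. `require_supported_python` is a hypothetical helper, not part of the repository.

```python
import sys
from typing import Optional, Sequence

# Exclusive upper bound taken from the note in the usage example's docstring.
MAX_PYTHON_EXCLUSIVE = (3, 14)


def require_supported_python(version: Optional[Sequence[int]] = None) -> None:
    """Hypothetical guard: raise RuntimeError on Python >= 3.14.

    `version` defaults to the running interpreter; passing a tuple such as
    (3, 12) makes the check easy to exercise.
    """
    current = tuple(version if version is not None else sys.version_info[:2])[:2]
    if current >= MAX_PYTHON_EXCLUSIVE:
        major, minor = MAX_PYTHON_EXCLUSIVE
        raise RuntimeError(
            f"ModernBERT needs Python < {major}.{minor}; found {current[0]}.{current[1]}"
        )


require_supported_python((3, 12))  # supported version: returns without raising
```

In `modernbert_example.py`, calling `require_supported_python()` with no arguments before constructing `SharedTextBackbone` would stop early with a clear message instead of failing later inside `torch.compile`.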

## Citation

If you use this model, please cite:

```bibtex
@software{tag_module,
  title  = {TAG Module: Persona-Conditioned Contrastive Learning for Tag Classification},
  author = {Your Name},
  year   = {2025},
  url    = {https://github.com/yourusername/tag-module}
}
```

## License

Please refer to the license of the original backbone model, `answerdotai/ModernBERT-base`.
|