---
library_name: pytorch
tags:
- contrastive-learning
- tag-classification
- semantic-search
- embeddings
- persona-conditioned
- pretrained-backbone
---

# embeddinggemma-300m-tag-classification

This is a **pretrained backbone model** (`google/embeddinggemma-300m`) used for tag classification within our contrastive-learning framework.

## Model Description

This model uses the `google/embeddinggemma-300m` backbone directly, without fine-tuning. It is intended for zero-shot tag classification tasks where a pretrained embedding model ranks candidate tags by semantic similarity to a query.
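
A minimal sketch of this zero-shot setup, assuming `sentence-transformers` can load the backbone directly (the repository's own `SharedTextBackbone` wrapper is used in the full example below): candidate tags are ranked by cosine similarity to the query embedding.

```python
# Minimal zero-shot sketch: rank candidate tags by cosine similarity to a query.
# Assumes sentence-transformers can load google/embeddinggemma-300m directly;
# the repository's SharedTextBackbone wrapper (full example below) is the
# project's own path.
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("google/embeddinggemma-300m")

query = "How to implement OAuth2 authentication in a Python Flask API?"
tags = ["python", "flask", "oauth2", "security", "frontend"]

query_emb = model.encode(query, convert_to_tensor=True)
tag_embs = model.encode(tags, convert_to_tensor=True)

# Cosine similarity between the query and every candidate tag, highest first.
scores = util.cos_sim(query_emb, tag_embs)[0]
for tag, score in sorted(zip(tags, scores.tolist()), key=lambda t: -t[1]):
    print(f"{tag:16s} {score:.4f}")
```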

## Usage

See the repository's README.md for detailed usage examples built on our module abstractions.

## Model Architecture

- **Backbone**: `google/embeddinggemma-300m`
- **Type**: Pretrained backbone (no fine-tuning)
- **Embedding Dimension**: 768 for `google/embeddinggemma-300m`; varies with the chosen backbone (see the sketch below)
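
As a quick sanity check, the embedding width can be read from the backbone's Hugging Face config; a minimal sketch, assuming the config exposes `hidden_size` (as Gemma-family configs do):

```python
# Read the backbone's embedding width from its Hugging Face config.
# Assumes the config exposes `hidden_size`; expected to be 768 for
# google/embeddinggemma-300m, but other backbones will differ.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/embeddinggemma-300m")
print(config.hidden_size)
```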

## Usage Example

```python
"""
Example: Using EmbeddingGemma-300m for Tag Classification

This example shows how to use the pretrained EmbeddingGemma-300m backbone
for zero-shot tag classification using our module abstractions.

Installation:
    pip install git+https://github.com/Pieces/TAG-module.git@main
    # Or: pip install -e .
"""

import torch

from tags_model.models.backbone import SharedTextBackbone
from playground.validate_from_checkpoint import compute_ranked_tags

# Load the pretrained backbone
print("Loading EmbeddingGemma-300m...")
backbone = SharedTextBackbone(
    model_name="google/embeddinggemma-300m",
    embedding_dim=768,
    freeze_backbone=True,
    pooling_mode="mean",
)
backbone.eval()
print("✓ Model loaded!")

# Example query
query_text = "How to implement OAuth2 authentication in a Python Flask API?"

# Candidate tags to rank
candidate_tags = [
    "python", "flask", "oauth2", "authentication", "api",
    "security", "web-development", "jwt", "rest-api", "backend",
]

print(f"\nQuery: {query_text}")
print(f"Candidate tags: {candidate_tags}\n")

# Encode query and tags
with torch.inference_mode():
    query_emb = backbone.encode_texts([query_text], max_length=2048, return_dict=False)[0]
    tag_embs = backbone.encode_texts(candidate_tags, max_length=2048, return_dict=False)

print(f"Query embedding shape: {query_emb.shape}")
print(f"Tag embeddings shape: {tag_embs.shape}")

# Rank tags by similarity
ranked_tags = compute_ranked_tags(
    query_emb=query_emb,
    pos_embs=torch.empty(0, 768),  # No positives for zero-shot
    neg_embs=torch.empty(0, 768),  # No negatives for zero-shot
    general_embs=tag_embs,
    positive_tags=[],
    negative_tags=[],
    general_tags=candidate_tags,
)

# Display the top-ranked tags
print("\n" + "=" * 60)
print("Top Ranked Tags:")
print("=" * 60)
for tag, rank, label, score in ranked_tags[:5]:
    print(f"{rank:2d}. {tag:20s} (score: {score:.4f})")

print("\n" + "=" * 60)
print("Example complete!")
```

### Running the Example

```bash
# Install the repository first
pip install git+https://github.com/Pieces/TAG-module.git@main
# Or for local development:
pip install -e .

# Run the example
python embeddinggemma_example.py
```

## Citation

If you use this model, please cite:

```bibtex
@software{tag_module,
  title  = {TAG Module: Persona-Conditioned Contrastive Learning for Tag Classification},
  author = {Your Name},
  year   = {2025},
  url    = {https://github.com/yourusername/tag-module}
}
```

## License

Please refer to the license of the original backbone model, `google/embeddinggemma-300m`.