Summary Classifier (with Gemma Embeddings)
This repository contains a custom PyTorch model for summary classification. It embeds the input text with google/embeddinggemma-300m (loaded through sentence-transformers) and passes the embedding to a custom classification head.
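The card does not describe the head's layers; they are defined in model.py. As a rough illustration only, the sketch below assumes the head is dropout followed by a single linear layer that maps a sentence embedding to one logit per class. `head_forward` and the toy weights are hypothetical, and plain Python lists stand in for PyTorch tensors so the arithmetic is visible.

```python
import random


def head_forward(embedding, weight, bias, dropout=0.1, training=False):
    """Hypothetical classification head: dropout, then one linear layer.

    embedding: list[float] of length d (one sentence embedding)
    weight:    num_classes rows, each a list[float] of length d
    bias:      list[float] of length num_classes
    Returns a list of num_classes raw logits.
    """
    if training and dropout > 0:
        # Inverted dropout: drop each feature with probability `dropout`
        # and rescale the survivors so the expected value is unchanged.
        keep = 1.0 - dropout
        embedding = [x / keep if random.random() < keep else 0.0 for x in embedding]
    # In eval mode (after model.eval()), dropout does nothing.
    return [sum(w * x for w, x in zip(row, embedding)) + b
            for row, b in zip(weight, bias)]


# Toy 3-dimensional embedding and 2 classes (made-up numbers).
emb = [0.5, -1.0, 2.0]
W = [[1.0, 0.0, 0.5], [-1.0, 1.0, 0.0]]
b = [0.1, -0.2]
print(head_forward(emb, W, b))  # eval mode: [1.6, -1.7]
```

The real head may use more layers or a different activation; only the dropout value and the number of classes are confirmed by the checkpoint keys used in the loading code.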
How to Use
First, install the required libraries:
```bash
pip install -r requirements.txt
```
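You can then load and use the model with the following Python code. It downloads the classifier weights from this repository with hf_hub_download, and SentenceTransformer downloads the embedding model from the Hub. SummaryClassifier is imported from model.py, so that file must be importable from your working directory (for example, run the script from a clone of this repository).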
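```python
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

# SummaryClassifier is defined in model.py from this repository;
# that file must be importable from the current working directory.
from model import SummaryClassifier

REPO_ID = "Prahaladha/summary-gemma-classifier"
WEIGHTS_FILE = "summary_classifier_gemma.pth"

# Download the checkpoint (cached locally after the first call).
print(f"Downloading model from {REPO_ID}...")
model_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILE)

# Load on CPU first; the model is moved to the target device below.
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))

# Configuration stored alongside the head weights.
embedder_name = checkpoint["embedder_name"]
num_classes = checkpoint["num_classes"]
dropout = checkpoint.get("dropout", 0.1)  # fall back to 0.1 if not saved

print(f"Loading model with embedder: {embedder_name}")
print(f"Number of classes: {num_classes}")

# Build the embedder and wrap it with the classification head.
embedder = SentenceTransformer(embedder_name)
model = SummaryClassifier(
    embedder=embedder,
    num_classes=num_classes,
    dropout=dropout,
)

# The checkpoint stores only the head weights; the embedder weights
# come from the pretrained SentenceTransformer model loaded above.
model.head.load_state_dict(checkpoint["head_state_dict"])
model.eval()  # disable dropout for inference

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Model loaded and moved to {device}")

test_summaries = ["A concise and accurate recap.", "This was a long and winding explanation."]

with torch.no_grad():
    logits = model(test_summaries)                 # one row of class scores per input
    probs = torch.softmax(logits, dim=-1)          # per-class probabilities
    predicted_class = torch.argmax(probs, dim=-1)  # index of the most likely class

print("\n--- Inference Test ---")
print(f"Input: {test_summaries}")
print(f"Probs: {probs.cpu().numpy()}")
print(f"Predicted: {predicted_class.cpu().numpy()}")
```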
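The last step of the inference code turns raw logits into probabilities and a class index. To show that step without downloading the model, here is a plain-Python sketch of the same softmax and argmax; `softmax`, `predict`, and the example logits are hypothetical and are not output from this model.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(batch_logits):
    """Return (probabilities, predicted class index) for each row."""
    results = []
    for row in batch_logits:
        probs = softmax(row)
        # argmax: index of the largest probability
        pred = max(range(len(probs)), key=probs.__getitem__)
        results.append((probs, pred))
    return results


# Made-up logits for two inputs and two classes.
example = [[2.0, 0.0], [-1.0, 1.0]]
for probs, pred in predict(example):
    print([round(p, 3) for p in probs], pred)
# [0.881, 0.119] 0
# [0.119, 0.881] 1
```

Because softmax is monotonic, the argmax of the probabilities equals the argmax of the logits, so `torch.argmax(logits, dim=-1)` gives the same `predicted_class`. The softmax is only needed when you want the probabilities themselves.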