
# Summary Classifier (with Gemma Embeddings)

This repository contains a custom PyTorch model for summary classification. It uses `google/embeddinggemma-300m` to embed input text and a custom classification head on top of those embeddings.
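The `SummaryClassifier` class itself is defined in `model.py` in this repository. As a rough sketch of the architecture (an illustration only, inferred from the usage below; the class name, constructor arguments, and the `head` attribute match the loading code, but the actual layer layout in `model.py` is authoritative):

```python
import torch
import torch.nn as nn

class SummaryClassifier(nn.Module):
    """Sketch: a frozen sentence embedder plus a small trainable head."""

    def __init__(self, embedder, num_classes, dropout=0.1):
        super().__init__()
        self.embedder = embedder  # e.g. a SentenceTransformer instance
        embed_dim = embedder.get_sentence_embedding_dimension()
        # Only this head is trained; its state dict is what the checkpoint stores.
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, texts):
        # Embed raw strings without tracking gradients through the embedder.
        with torch.no_grad():
            embeddings = self.embedder.encode(texts, convert_to_tensor=True)
        return self.head(embeddings)
```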

## How to Use

First, install the required libraries:

```bash
pip install -r requirements.txt
```
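At a minimum, the usage code below imports the following packages (an assumption; the repository's `requirements.txt` is authoritative for exact pins):

```text
torch
sentence-transformers
huggingface_hub
```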

You can then load and use the model with the following Python code, which automatically downloads the model files from this repository.

```python
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from model import SummaryClassifier  # defined in model.py in this repository

REPO_ID = "Prahaladha/summary-gemma-classifier"
WEIGHTS_FILE = "summary_classifier_gemma.pth"

# Download the checkpoint from the Hub (cached locally after the first call).
print(f"Downloading model from {REPO_ID}...")
model_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILE)

checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# Read the model configuration stored alongside the weights.
embedder_name = checkpoint['embedder_name']
num_classes = checkpoint['num_classes']
dropout = checkpoint.get('dropout', 0.1)

print(f"Loading model with embedder: {embedder_name}")
print(f"Number of classes: {num_classes}")

# Rebuild the model: the pretrained sentence embedder plus the custom head.
embedder = SentenceTransformer(embedder_name)

model = SummaryClassifier(
    embedder=embedder,
    num_classes=num_classes,
    dropout=dropout,
)

# Only the classification head was trained, so only its weights are restored.
model.head.load_state_dict(checkpoint['head_state_dict'])
model.eval()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
print(f"Model loaded and moved to {device}")

# Quick inference test on a couple of example summaries.
test_summaries = ["A concise and accurate recap.", "This was a long and winding explanation."]
with torch.no_grad():
    logits = model(test_summaries)
    probs = torch.softmax(logits, dim=-1)
    predicted_class = torch.argmax(probs, dim=-1)

    print("\n--- Inference Test ---")
    print(f"Input: {test_summaries}")
    print(f"Probs: {probs.cpu().numpy()}")
    print(f"Predicted: {predicted_class.cpu().numpy()}")
```
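For reference, a checkpoint with the keys the loading code expects could be produced roughly like this (a hedged sketch: the real training script and head dimensions live in this repository, and the `nn.Linear(8, 3)` stand-in below is purely illustrative):

```python
import torch
import torch.nn as nn

# Stand-in for the trained classification head; the real head's shape is
# determined by model.py and the embedder's output dimension.
head = nn.Linear(8, 3)

checkpoint = {
    "head_state_dict": head.state_dict(),           # trained head weights
    "embedder_name": "google/embeddinggemma-300m",  # embedder to rebuild
    "num_classes": 3,
    "dropout": 0.1,
}
torch.save(checkpoint, "summary_classifier_gemma.pth")
```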