# Summary Classifier (with Gemma Embeddings)

This repository contains a custom PyTorch model for summary classification. It uses `google/embeddinggemma-300m` for embeddings and a custom classification head.

## How to Use

First, install the required libraries:

```bash
pip install -r requirements.txt
```

You can then load and use the model with the following Python code. The code will automatically download the model files from this repository.

```python
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from model import SummaryClassifier

REPO_ID = "Prahaladha/summary-gemma-classifier"
WEIGHTS_FILE = "summary_classifier_gemma.pth"

print(f"Downloading model from {REPO_ID}...")
model_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILE)

checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

embedder_name = checkpoint['embedder_name']
num_classes = checkpoint['num_classes']
dropout = checkpoint.get('dropout', 0.1)

print(f"Loading model with embedder: {embedder_name}")
print(f"Number of classes: {num_classes}")

embedder = SentenceTransformer(embedder_name)

model = SummaryClassifier(
    embedder=embedder,
    num_classes=num_classes,
    dropout=dropout
)

model.head.load_state_dict(checkpoint['head_state_dict'])
model.eval()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
print(f"Model loaded and moved to {device}")

test_summaries = ["A concise and accurate recap.", "This was a long and winding explanation."]

with torch.no_grad():
    logits = model(test_summaries)
    probs = torch.softmax(logits, dim=-1)
    predicted_class = torch.argmax(probs, dim=-1)

print("\n--- Inference Test ---")
print(f"Input: {test_summaries}")
print(f"Probs: {probs.cpu().numpy()}")
print(f"Predicted: {predicted_class.cpu().numpy()}")
```