| # Summary Classifier (with Gemma Embeddings) | |
| This repository contains a custom PyTorch model for summary classification. It uses `google/embeddinggemma-300m` for embeddings and a custom classification head. | |
| ## How to Use | |
| First, install the required libraries. Run this from a directory that contains this repository's `requirements.txt`: | |
| ```bash | |
python -m pip install -r requirements.txt
| ``` | |
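| You can then load and use the model with the following Python code. It downloads the classifier weights (`summary_classifier_gemma.pth`) from this repository, and `sentence-transformers` loads the embedding model whose name is stored in the checkpoint. The script imports `SummaryClassifier` from a local `model` module, so make sure this repository's `model.py` is in your working directory (or otherwise importable). | |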
| ```python | |
# Example: download the trained classification head, rebuild the classifier
# around its sentence embedder, and run a quick inference check.
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

# Local module: `model.py` must be importable (e.g. in the working directory).
from model import SummaryClassifier

REPO_ID = "Prahaladha/summary-gemma-classifier"
WEIGHTS_FILE = "summary_classifier_gemma.pth"

# Fetch the checkpoint from the Hub (hf_hub_download returns a local path).
print(f"Downloading model from {REPO_ID}...")
model_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILE)

# Load onto the CPU first so the checkpoint opens on machines without a GPU.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# downloaded file cannot run arbitrary code while it is being loaded.
checkpoint = torch.load(
    model_path,
    map_location=torch.device("cpu"),
    weights_only=True,
)

embedder_name = checkpoint["embedder_name"]
num_classes = checkpoint["num_classes"]
# Fall back to a dropout of 0.1 when the checkpoint does not store one.
dropout = checkpoint.get("dropout", 0.1)

print(f"Loading model with embedder: {embedder_name}")
print(f"Number of classes: {num_classes}")

embedder = SentenceTransformer(embedder_name)
model = SummaryClassifier(
    embedder=embedder,
    num_classes=num_classes,
    dropout=dropout,
)
# The checkpoint supplies weights only for the classification head; the
# embedder keeps the weights SentenceTransformer loaded from `embedder_name`.
model.head.load_state_dict(checkpoint["head_state_dict"])
model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Model loaded and moved to {device}")

test_summaries = [
    "A concise and accurate recap.",
    "This was a long and winding explanation.",
]

# Inference only: disable autograd bookkeeping for the forward pass.
with torch.no_grad():
    logits = model(test_summaries)
    probs = torch.softmax(logits, dim=-1)
    predicted_class = torch.argmax(probs, dim=-1)

print("\n--- Inference Test ---")
print(f"Input: {test_summaries}")
print(f"Probs: {probs.cpu().numpy()}")
print(f"Predicted: {predicted_class.cpu().numpy()}")
| ``` |