# Summary Classifier (with Gemma Embeddings)

This repository contains a custom PyTorch model for summary classification. It uses `google/embeddinggemma-300m` for embeddings and a custom classification head.

## How to Use

First, install the required libraries:

```bash
pip install -r requirements.txt
```

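The pinned versions live in the repository's `requirements.txt`; at minimum, the Python example below needs these packages (a minimal, assumed set):

```text
torch
sentence-transformers
huggingface_hub
```
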
You can then load and use the model with the following Python code. It automatically downloads the model weights from this repository.

```python
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

from model import SummaryClassifier

REPO_ID = "Prahaladha/summary-gemma-classifier"
WEIGHTS_FILE = "summary_classifier_gemma.pth"

# Download the checkpoint from the Hugging Face Hub.
print(f"Downloading model from {REPO_ID}...")
model_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILE)

checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

# The checkpoint stores the embedder name and the classifier configuration.
embedder_name = checkpoint['embedder_name']
num_classes = checkpoint['num_classes']
dropout = checkpoint.get('dropout', 0.1)

print(f"Loading model with embedder: {embedder_name}")
print(f"Number of classes: {num_classes}")

# Rebuild the model: a SentenceTransformer embedder plus the custom head.
embedder = SentenceTransformer(embedder_name)
model = SummaryClassifier(
    embedder=embedder,
    num_classes=num_classes,
    dropout=dropout,
)

# Only the classification head's weights are stored in the checkpoint;
# the embedder is re-downloaded by name above.
model.head.load_state_dict(checkpoint['head_state_dict'])
model.eval()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
print(f"Model loaded and moved to {device}")

# Quick inference test: the model takes a list of raw strings.
test_summaries = ["A concise and accurate recap.", "This was a long and winding explanation."]
with torch.no_grad():
    logits = model(test_summaries)
    probs = torch.softmax(logits, dim=-1)
    predicted_class = torch.argmax(probs, dim=-1)

print("\n--- Inference Test ---")
print(f"Input: {test_summaries}")
print(f"Probs: {probs.cpu().numpy()}")
print(f"Predicted: {predicted_class.cpu().numpy()}")
```
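
For reference, here is a minimal sketch of what the `SummaryClassifier` in `model.py` might look like; the actual implementation in this repository may differ. It assumes the embedder exposes `encode()` and `get_sentence_embedding_dimension()`, as `sentence-transformers` models do, and that `head` is the only trainable part (matching the `head_state_dict` key in the checkpoint):

```python
import torch
import torch.nn as nn


class SummaryClassifier(nn.Module):
    """Hypothetical sketch: frozen sentence embedder + trainable linear head."""

    def __init__(self, embedder, num_classes, dropout=0.1):
        super().__init__()
        self.embedder = embedder  # sentence encoder, kept frozen
        dim = embedder.get_sentence_embedding_dimension()
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(dim, num_classes),
        )

    def forward(self, texts):
        # Embed a list of raw strings without tracking gradients,
        # then classify the pooled sentence vectors.
        with torch.no_grad():
            emb = self.embedder.encode(texts, convert_to_tensor=True)
        return self.head(emb)
```

A head-only design like this keeps the checkpoint small: only `head_state_dict` needs to be saved, while the Gemma embedder is fetched from the Hub by name rather than stored in the `.pth` file.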