Upload train_grpo.py with huggingface_hub
Browse files- train_grpo.py +1 -1
train_grpo.py
CHANGED
|
@@ -158,7 +158,7 @@ class QMDRewardFunction:
|
|
| 158 |
self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
|
| 159 |
print("Embedding model loaded.")
|
| 160 |
|
| 161 |
-
def __call__(self, completions: list[str], prompts: list[str] = None) -> list[float]:
|
| 162 |
"""Compute rewards for a batch of completions."""
|
| 163 |
rewards = []
|
| 164 |
|
|
|
|
| 158 |
self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
|
| 159 |
print("Embedding model loaded.")
|
| 160 |
|
| 161 |
+
def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
|
| 162 |
"""Compute rewards for a batch of completions."""
|
| 163 |
rewards = []
|
| 164 |
|