Khriis commited on
Commit
d37eeb3
·
1 Parent(s): 38ce35c

Added handler

Browse files
Files changed (4) hide show
  1. config.json +4 -0
  2. cross_scorer_model.py +0 -1
  3. handler.py +60 -0
  4. requirements.txt +2 -0
config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "task": "text-classification",
3
+ "custom_handler": "handler.py"
4
+ }
cross_scorer_model.py CHANGED
@@ -10,7 +10,6 @@ from transformers import BertForMaskedLM
10
 
11
  import torch.nn.functional as F
12
 
13
- import spacy
14
  import transformers
15
  import torch.nn as nn
16
 
 
10
 
11
  import torch.nn.functional as F
12
 
 
13
  import transformers
14
  import torch.nn as nn
15
 
handler.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import importlib.util
3
+ import sys
4
+ import pathlib
5
+ from transformers import AutoModel, AutoTokenizer
6
+
7
+ class InferenceHandler:
8
+ def __init__(self):
9
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ # Import custom model definition from local file
12
+ model_path = "cross_scorer_model.py"
13
+ spec = importlib.util.spec_from_file_location("cross_scorer_model", model_path)
14
+ mod = importlib.util.module_from_spec(spec)
15
+ sys.modules["cross_scorer_model"] = mod
16
+ spec.loader.exec_module(mod)
17
+
18
+ # Initialize encoder and custom model
19
+ encoder = AutoModel.from_pretrained("roberta-base", add_pooling_layer=False)
20
+ self.model = mod.CrossScorerCrossEncoder(encoder).to(self.device)
21
+
22
+ # Load weights
23
+ weights_path = "reflection_scorer_weight.pt"
24
+ state = torch.load(weights_path, map_location=self.device)
25
+ sd = state.get("model_state_dict", state)
26
+ self.model.load_state_dict(sd, strict=False)
27
+
28
+ self.model.eval()
29
+
30
+ # Initialize tokenizer
31
+ self.tokenizer = AutoTokenizer.from_pretrained("roberta-base")
32
+
33
+ def handle(self, inputs: list) -> list:
34
+ results = []
35
+ for item in inputs:
36
+ prompt = item.get("prompt")
37
+ response = item.get("response")
38
+
39
+ if not prompt or not response:
40
+ # Handle missing keys gracefully, though instructions imply strict format
41
+ results.append({"error": "Missing prompt or response"})
42
+ continue
43
+
44
+ # Preprocessing
45
+ batch = self.tokenizer(
46
+ prompt,
47
+ response,
48
+ padding="longest",
49
+ truncation=True,
50
+ return_tensors="pt"
51
+ ).to(self.device)
52
+
53
+ # Inference
54
+ with torch.no_grad():
55
+ # score_forward returns raw logits (based on README/code usage), we need sigmoid
56
+ score = self.model.score_forward(**batch).sigmoid().item()
57
+
58
+ results.append({"score": round(score, 4)})
59
+
60
+ return results
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ torch