pashaa committed on
Commit
0518d49
·
verified ·
1 Parent(s): 66b30c4

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -1,35 +1,10 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
4
  *.h5 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
10
+ spiece.model filter=lfs diff=lfs merge=lfs -text
 
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ This model is a T5-base reranker fine-tuned on the MS MARCO passage dataset for 100k steps (or 10 epochs).
2
+
3
+ For better zero-shot performance (i.e., inference on other datasets), we recommend using `castorini/monot5-base-msmarco-10k`.
4
+
5
+ For more details on how to use it, check the following links:
6
+ - [A simple reranking example](https://github.com/castorini/pygaggle#a-simple-reranking-example)
7
+ - [Rerank MS MARCO passages](https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-passage-subset.md)
8
+ - [Rerank Robust04 documents](https://github.com/castorini/pygaggle/blob/master/docs/experiments-robust04-monot5-gpu.md)
9
+
10
+ Paper describing the model: [Document Ranking with a Pretrained Sequence-to-Sequence Model](https://www.aclweb.org/anthology/2020.findings-emnlp.63/)
config.json ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_num_labels": 2,
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "bos_token_id": null,
7
+ "d_ff": 3072,
8
+ "d_kv": 64,
9
+ "d_model": 768,
10
+ "decoder_start_token_id": 0,
11
+ "do_sample": false,
12
+ "dropout_rate": 0.1,
13
+ "early_stopping": false,
14
+ "eos_token_id": 1,
15
+ "finetuning_task": null,
16
+ "id2label": {
17
+ "0": "LABEL_0",
18
+ "1": "LABEL_1"
19
+ },
20
+ "initializer_factor": 1.0,
21
+ "is_decoder": false,
22
+ "is_encoder_decoder": true,
23
+ "label2id": {
24
+ "LABEL_0": 0,
25
+ "LABEL_1": 1
26
+ },
27
+ "layer_norm_epsilon": 1e-06,
28
+ "length_penalty": 1.0,
29
+ "max_length": 20,
30
+ "min_length": 0,
31
+ "model_type": "t5",
32
+ "n_positions": 512,
33
+ "no_repeat_ngram_size": 0,
34
+ "num_beams": 1,
35
+ "num_heads": 12,
36
+ "num_layers": 12,
37
+ "num_return_sequences": 1,
38
+ "output_attentions": false,
39
+ "output_hidden_states": false,
40
+ "output_past": true,
41
+ "pad_token_id": 0,
42
+ "prefix": null,
43
+ "pruned_heads": {},
44
+ "relative_attention_num_buckets": 32,
45
+ "repetition_penalty": 1.0,
46
+ "task_specific_params": {
47
+ "summarization": {
48
+ "early_stopping": true,
49
+ "length_penalty": 2.0,
50
+ "max_length": 200,
51
+ "min_length": 30,
52
+ "no_repeat_ngram_size": 3,
53
+ "num_beams": 4,
54
+ "prefix": "summarize: "
55
+ },
56
+ "translation_en_to_de": {
57
+ "early_stopping": true,
58
+ "max_length": 300,
59
+ "num_beams": 4,
60
+ "prefix": "translate English to German: "
61
+ },
62
+ "translation_en_to_fr": {
63
+ "early_stopping": true,
64
+ "max_length": 300,
65
+ "num_beams": 4,
66
+ "prefix": "translate English to French: "
67
+ },
68
+ "translation_en_to_ro": {
69
+ "early_stopping": true,
70
+ "max_length": 300,
71
+ "num_beams": 4,
72
+ "prefix": "translate English to Romanian: "
73
+ }
74
+ },
75
+ "temperature": 1.0,
76
+ "top_k": 50,
77
+ "top_p": 1.0,
78
+ "torchscript": false,
79
+ "use_bfloat16": false,
80
+ "vocab_size": 32128
81
+ }
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd58affd5786e328e6a1afadc39cc33d63e2a8f111bbbccc69212395c2f38592
3
+ size 891625348
handler.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom handler for MonoT5 reranking on HuggingFace Inference Endpoints.
3
+
4
+ Returns relevance probability scores for query-document pairs.
5
+ """
6
+
7
+ import math
8
+ from typing import Any, Dict, List
9
+
10
+ import torch
11
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
12
+
13
+
14
class EndpointHandler:
    """Handler for MonoT5 relevance scoring.

    Each prompt of the form ``"Query: ... Document: ... Relevant:"`` is fed to
    the T5 encoder, a single decoder step is run, and the logits of the
    ``"true"`` and ``"false"`` tokens are turned into a relevance probability
    with a two-way softmax.
    """

    # Upper bound on the number of prompts sent through the model in one
    # forward pass; larger requests are processed in consecutive chunks.
    BATCH_SIZE = 16

    # Prompts longer than this many tokens are truncated by the tokenizer.
    MAX_INPUT_TOKENS = 512

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from ``path`` and pick a device.

        Args:
            path: Directory (or model id) handed to ``from_pretrained`` for
                both the tokenizer and the model.
        """
        self.tokenizer = T5Tokenizer.from_pretrained(path)
        self.model = T5ForConditionalGeneration.from_pretrained(path)
        self.model.eval()

        # Move to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        # First sub-token id of "true" / "false"; these are the only two
        # vocabulary entries whose logits are compared when scoring.
        self.true_id = self.tokenizer.encode("true", add_special_tokens=False)[0]
        self.false_id = self.tokenizer.encode("false", add_special_tokens=False)[0]

        print(f"MonoT5 loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Accepts either:
            - {"inputs": "Query: ... Document: ... Relevant:"} - single input
            - {"inputs": ["Query: ... Document: ... Relevant:", ...]} - batch
            - {"query": "...", "documents": ["...", ...]} - structured input
            - {"inputs": {"query": "...", "documents": [...]}} - structured
              input nested under "inputs"

        Returns:
            - List of {"score": float, "label": "true"/"false"} dicts, one per
              prompt and in the same order as the prompts.

        Raises:
            ValueError: If the payload matches none of the accepted shapes.
        """
        prompts = self._collect_prompts(data)

        results = []
        for start in range(0, len(prompts), self.BATCH_SIZE):
            chunk = prompts[start:start + self.BATCH_SIZE]
            for score in self._score_batch(chunk):
                results.append({
                    "score": score,
                    "label": "true" if score > 0.5 else "false"
                })

        return results

    @staticmethod
    def _collect_prompts(data: Dict[str, Any]) -> List[str]:
        """Turn a request payload into a flat list of prompt strings."""
        payload = data.get("inputs", data)

        # Structured form: top-level keys take precedence (as before), then
        # the same keys nested under "inputs".
        for candidate in (data, payload):
            if (
                isinstance(candidate, dict)
                and "query" in candidate
                and "documents" in candidate
            ):
                query = candidate["query"]
                documents = candidate["documents"]
                # A lone document string must not be iterated per character.
                if isinstance(documents, str):
                    documents = [documents]
                return [
                    f"Query: {query} Document: {doc} Relevant:"
                    for doc in documents
                ]

        if isinstance(payload, str):
            return [payload]

        if isinstance(payload, (list, tuple)):
            if not all(isinstance(text, str) for text in payload):
                raise ValueError('every element of "inputs" must be a string')
            return list(payload)

        # Anything else (e.g. a dict without query/documents) used to be
        # iterated blindly, which scored dictionary keys as prompts.
        raise ValueError(
            'expected "inputs" as a string or list of strings, '
            'or "query" together with "documents"'
        )

    @staticmethod
    def _true_probability(true_logit: float, false_logit: float) -> float:
        """Two-way softmax over the true/false logits; returns P(true)."""
        # Subtract the max before exponentiating so exp() cannot overflow.
        max_logit = max(true_logit, false_logit)
        true_prob = math.exp(true_logit - max_logit)
        false_prob = math.exp(false_logit - max_logit)
        return true_prob / (true_prob + false_prob)

    def _score_batch(self, texts: List[str]) -> List[float]:
        """Score several prompts with one forward pass.

        Args:
            texts: Prompt strings; may be empty.

        Returns:
            One relevance probability per prompt, in input order.
        """
        if not texts:
            return []

        # NOTE(review): truncation applies to the whole prompt, so a very long
        # document may lose the trailing "Relevant:" cue - confirm that callers
        # pre-segment long documents.
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            max_length=self.MAX_INPUT_TOKENS,
            truncation=True,
            padding=True
        ).to(self.device)

        # A single decoder step per prompt, started from the pad token, so the
        # last (only) position holds the logits of the first generated token.
        decoder_input_ids = torch.full(
            (len(texts), 1),
            self.tokenizer.pad_token_id,
            dtype=torch.long,
            device=self.device
        )

        with torch.no_grad():
            outputs = self.model(
                **encoded,
                decoder_input_ids=decoder_input_ids
            )
            logits = outputs.logits[:, -1, :]

        true_logits = logits[:, self.true_id].tolist()
        false_logits = logits[:, self.false_id].tolist()
        return [
            self._true_probability(true_logit, false_logit)
            for true_logit, false_logit in zip(true_logits, false_logits)
        ]

    def _score_single(self, input_text: str) -> float:
        """Score a single query-document pair."""
        return self._score_batch([input_text])[0]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64467f69fc891a29b35b386b7d66e4a3cdb2285588dcc85b56c396eb3a31b398
3
+ size 891691413
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers>=4.30.0
2
+ torch>=2.0.0
3
+ sentencepiece
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"]}
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 100, "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"], "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "t5-base"}