Commit 714cdca by nickcdryan · verified · 1 parent: 8bc2bec

Upload Standard InfoNCE retrieval model

Files changed (3)
  1. README.md +104 -0
  2. config.json +25 -0
  3. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,104 @@
+ ---
+ license: apache-2.0
+ base_model: google-bert/bert-base-uncased
+ tags:
+ - retrieval
+ - information-retrieval
+ - sentence-transformers
+ - bert
+ - msmarco
+ - squad
+ pipeline_tag: feature-extraction
+ ---
+
+ # nickcdryan/bitter-retrieval-standard-infonce-bert
+
+ This is a retrieval model fine-tuned with **Standard InfoNCE** on the MS MARCO dataset, with additional validation on SQuAD.
+
+ ## Model Details
+
+ - **Base Model**: google-bert/bert-base-uncased
+ - **Training Method**: Standard InfoNCE
+ - **Training Data**: MS MARCO soft-labeled dataset
+ - **Validation Data**: SQuAD v2 + MS MARCO
+ - **Framework**: PyTorch + Transformers
+
+ ## Training Details
+
+ This model was trained with the bitter-retrieval framework using the following settings:
+
+ - **Training Method**: `Standard InfoNCE`
+ - **Encoder**: BERT-base-uncased
+ - **Max Sequence Length**: 512 tokens
+ - **Batch Size**: 32
+ - **Epochs**: 2
+ - **Learning Rate**: 2e-5
+ - **Temperature**: 0.02
+
+ ## Usage
+
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+ import torch
+ import torch.nn.functional as F
+
+ # Load model and tokenizer
+ model = AutoModel.from_pretrained("nickcdryan/bitter-retrieval-standard-infonce-bert")
+ tokenizer = AutoTokenizer.from_pretrained("nickcdryan/bitter-retrieval-standard-infonce-bert")
+
+ def encode_text(text, prefix=""):
+     """Encode text (with an optional prefix) into an L2-normalized embedding."""
+     full_text = f"{prefix}{text}" if prefix else text
+     inputs = tokenizer(full_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+         # Mean pooling over non-padding tokens
+         attention_mask = inputs["attention_mask"]
+         token_embeddings = outputs.last_hidden_state
+         masked_embeddings = token_embeddings * attention_mask.unsqueeze(-1)
+         sum_embeddings = masked_embeddings.sum(dim=1)
+         count_tokens = attention_mask.sum(dim=1, keepdim=True)
+         embeddings = sum_embeddings / count_tokens
+         # L2 normalize
+         embeddings = F.normalize(embeddings, dim=-1)
+
+     return embeddings
+
+ # Example usage
+ query = "What is machine learning?"
+ passage = "Machine learning is a subset of artificial intelligence..."
+
+ # Encode with prefixes (recommended)
+ query_emb = encode_text(query, "query: ")
+ passage_emb = encode_text(passage, "passage: ")
+
+ # Compute similarity
+ similarity = torch.cosine_similarity(query_emb, passage_emb)
+ print(f"Similarity: {similarity.item():.4f}")
+ ```
+
+ ## Evaluation Metrics
+
+ The model was evaluated on both the SQuAD and MS MARCO datasets using the following metrics:
+ - **Retrieval Accuracy**: How often the correct passage is retrieved
+ - **F1 Score**: Token-level F1 between generated and reference answers
+ - **Exact Match**: Exact match between generated and reference answers
+ - **LLM Judge**: Semantic similarity judged by Gemini-2.0-flash
+
+ ## Training Framework
+
+ This model was trained using the [bitter-retrieval](https://github.com/yourusername/bitter-retrieval) framework, which implements various contrastive learning methods for retrieval tasks.
+
+ ## Citation
+
+ If you use this model, please cite:
+
+ ```bibtex
+ @misc{bitter-retrieval-standard-infonce,
+   title={Bitter Retrieval: Standard InfoNCE Fine-tuned BERT for Information Retrieval},
+   author={Your Name},
+   year={2024},
+   howpublished={\url{https://huggingface.co/nickcdryan/bitter-retrieval-standard-infonce-bert}}
+ }
+ ```
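The README names the training objective (Standard InfoNCE, temperature 0.02) but does not show it. The sketch below is one common formulation, not necessarily what bitter-retrieval implements: it assumes in-batch negatives, where passage `i` is the positive for query `i` and every other passage in the batch is a negative. The function name `info_nce_loss` and the list-of-vectors input format are illustrative choices, and the code is pure Python so the arithmetic is easy to follow.

```python
import math

def info_nce_loss(query_embs, passage_embs, temperature=0.02):
    """In-batch InfoNCE (assumed formulation): passage i is the positive for query i.

    Inputs are lists of L2-normalized vectors, so each dot product is a cosine similarity.
    """
    losses = []
    for i, q in enumerate(query_embs):
        # Temperature-scaled similarity of query i to every passage in the batch
        logits = [sum(a * b for a, b in zip(q, p)) / temperature for p in passage_embs]
        # Numerically stable log-sum-exp over all candidates
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(x - m) for x in logits))
        # Cross-entropy with the matching passage as the target
        losses.append(log_denominator - logits[i])
    return sum(losses) / len(losses)

# Toy batch: two orthogonal unit vectors
aligned = info_nce_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = info_nce_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(f"aligned: {aligned:.6f}, swapped: {swapped:.2f}")
```

With a temperature of 0.02, cosine similarities are multiplied by 50 before the softmax, so a small similarity gap between the positive and a negative already dominates the loss. This is why the toy batch gives a loss near 0 when pairs are aligned and near 50 when they are swapped.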
config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "architectures": [
+     "BertModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "classifier_dropout": null,
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.53.2",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 30522
+ }
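As a sanity check on the config above, the parameter count can be derived from its dimensions. The sketch below assumes the standard `BertModel` layout (word, position, and token-type embeddings with LayerNorm; per-layer self-attention and feed-forward blocks, each followed by LayerNorm; and a pooler). The helper name `bert_param_count` is hypothetical and is not part of Transformers.

```python
def bert_param_count(cfg):
    """Count parameters of a standard BertModel (pooler included) from its config values."""
    h = cfg["hidden_size"]
    i = cfg["intermediate_size"]
    layer_norm = 2 * h  # weight + bias
    embeddings = (
        cfg["vocab_size"] + cfg["max_position_embeddings"] + cfg["type_vocab_size"]
    ) * h + layer_norm
    attention = 4 * (h * h + h) + layer_norm  # Q, K, V, and output projections
    feed_forward = (h * i + i) + (i * h + h) + layer_norm
    pooler = h * h + h
    return embeddings + cfg["num_hidden_layers"] * (attention + feed_forward) + pooler

# Values copied from config.json
config = {
    "hidden_size": 768,
    "intermediate_size": 3072,
    "max_position_embeddings": 512,
    "num_hidden_layers": 12,
    "type_vocab_size": 2,
    "vocab_size": 30522,
}

total = bert_param_count(config)
print(total)      # 109482240 parameters
print(total * 4)  # 437928960 bytes when stored as float32
```

At 4 bytes per float32 value this comes to about 438 MB of raw weights, which is the right order of magnitude for a BERT-base checkpoint saved with `"torch_dtype": "float32"`.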
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0fc83491dbe1b924e899d6f2d62783ede2a7762cb2a7b479f9b97ec8c9988190
+ size 437951328
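The three lines added for `model.safetensors` are a Git LFS pointer, not the weights themselves: they record the spec version, the SHA-256 of the real file, and its size in bytes. A minimal sketch of reading such a pointer follows; `parse_lfs_pointer` is a hypothetical helper written for illustration, not a Git LFS or Hugging Face API.

```python
def parse_lfs_pointer(text):
    """Parse the 'key value' lines of a Git LFS pointer file into a dict."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }

# Pointer contents copied from the diff above
pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:0fc83491dbe1b924e899d6f2d62783ede2a7762cb2a7b479f9b97ec8c9988190
size 437951328
"""

info = parse_lfs_pointer(pointer)
size_mib = info["size"] / 2**20
# Upper bound on the number of float32 values the file can hold (4 bytes each)
max_float32_values = info["size"] // 4
print(f"{info['hash_algo']} digest, {size_mib:.1f} MiB, at most {max_float32_values:,} float32 values")
```

The bound of roughly 109.5 million float32 values is in line with a BERT-base encoder, and the file is slightly larger than the raw weights because safetensors also stores a metadata header.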