update pooling for ST
Browse files- .DS_Store +0 -0
- 1_Pooling/config.json +10 -0
- README.md +2 -2
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
1_Pooling/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"word_embedding_dimension": 1024,
|
| 3 |
+
"pooling_mode_cls_token": true,
|
| 4 |
+
"pooling_mode_mean_tokens": false,
|
| 5 |
+
"pooling_mode_max_tokens": false,
|
| 6 |
+
"pooling_mode_mean_sqrt_len_tokens": false,
|
| 7 |
+
"pooling_mode_weightedmean_tokens": false,
|
| 8 |
+
"pooling_mode_lasttoken": false,
|
| 9 |
+
"include_prompt": true
|
| 10 |
+
}
|
README.md
CHANGED
|
@@ -73,14 +73,14 @@ from sentence_transformers import SentenceTransformer
|
|
| 73 |
from sentence_transformers.util import cos_sim
|
| 74 |
|
| 75 |
sentences = [
|
|
|
|
| 76 |
"def quick_sort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)",
|
| 77 |
"def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n for j in range(0, n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]\n return arr",
|
| 78 |
-
"how to implement quick sort in Python?"
|
| 79 |
]
|
| 80 |
|
| 81 |
model = SentenceTransformer('Salesforce/SFR-Embedding-Code-400M_R', trust_remote_code=True)
|
| 82 |
embeddings = model.encode(sentences)
|
| 83 |
-
print(cos_sim(embeddings[0], embeddings[1]))
|
| 84 |
```
|
| 85 |
|
| 86 |
### Citation
|
|
|
|
| 73 |
from sentence_transformers.util import cos_sim
|
| 74 |
|
| 75 |
sentences = [
|
| 76 |
+
"how to implement quick sort in Python?",
|
| 77 |
"def quick_sort(arr):\n if len(arr) <= 1:\n return arr\n pivot = arr[len(arr) // 2]\n left = [x for x in arr if x < pivot]\n middle = [x for x in arr if x == pivot]\n right = [x for x in arr if x > pivot]\n return quick_sort(left) + middle + quick_sort(right)",
|
| 78 |
"def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n for j in range(0, n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]\n return arr",
|
|
|
|
| 79 |
]
|
| 80 |
|
| 81 |
model = SentenceTransformer('Salesforce/SFR-Embedding-Code-400M_R', trust_remote_code=True)
|
| 82 |
embeddings = model.encode(sentences)
|
| 83 |
+
print(cos_sim(embeddings[0], embeddings[1:]))
|
| 84 |
```
|
| 85 |
|
| 86 |
### Citation
|