RASMUS commited on
Commit
4a4866a
·
verified ·
1 Parent(s): e318987

Upload webapp/src/sampling.ts with huggingface_hub

Browse files
Files changed (1) hide show
  1. webapp/src/sampling.ts +72 -0
webapp/src/sampling.ts ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export function applyRepPenalty(
2
+ logits: Float32Array,
3
+ generatedTokens: number[],
4
+ penalty: number,
5
+ ): Float32Array {
6
+ const next = new Float32Array(logits)
7
+ for (const token of new Set(generatedTokens)) {
8
+ next[token] = next[token] > 0 ? next[token] / penalty : next[token] * penalty
9
+ }
10
+ return next
11
+ }
12
+
13
+ export function applyMinP(logits: Float32Array, minP: number): Float32Array {
14
+ const next = new Float32Array(logits)
15
+ let maxLogit = -Infinity
16
+ for (const value of next) {
17
+ if (value > maxLogit) {
18
+ maxLogit = value
19
+ }
20
+ }
21
+
22
+ const probs = new Float64Array(next.length)
23
+ let total = 0
24
+ let maxProb = 0
25
+ for (let index = 0; index < next.length; index += 1) {
26
+ const prob = Math.exp(next[index] - maxLogit)
27
+ probs[index] = prob
28
+ total += prob
29
+ if (prob > maxProb) {
30
+ maxProb = prob
31
+ }
32
+ }
33
+
34
+ const threshold = (maxProb / total) * minP
35
+ for (let index = 0; index < next.length; index += 1) {
36
+ if (probs[index] / total < threshold) {
37
+ next[index] = -1e9
38
+ }
39
+ }
40
+
41
+ return next
42
+ }
43
+
44
+ export function sampleWithTemperature(
45
+ logits: Float32Array,
46
+ temperature: number,
47
+ ): number {
48
+ let maxLogit = -Infinity
49
+ for (const value of logits) {
50
+ if (value > maxLogit) {
51
+ maxLogit = value
52
+ }
53
+ }
54
+
55
+ const probs = new Float64Array(logits.length)
56
+ let total = 0
57
+ for (let index = 0; index < logits.length; index += 1) {
58
+ const prob = Math.exp(logits[index] / temperature - maxLogit / temperature)
59
+ probs[index] = prob
60
+ total += prob
61
+ }
62
+
63
+ let threshold = Math.random() * total
64
+ for (let index = 0; index < probs.length; index += 1) {
65
+ threshold -= probs[index]
66
+ if (threshold <= 0) {
67
+ return index
68
+ }
69
+ }
70
+
71
+ return probs.length - 1
72
+ }