File size: 1,720 Bytes
4a4866a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
/**
 * Applies a repetition penalty to every token id that has already been generated.
 *
 * Each distinct id in `generatedTokens` is penalized exactly once, regardless of
 * how many times it repeats: positive logits are divided by `penalty`, zero or
 * negative logits are multiplied by it, so a penalty > 1 always makes the token
 * less likely. Ids outside `[0, logits.length)` are ignored (typed-array writes
 * out of range are no-ops).
 *
 * @param logits - Source logits; not mutated.
 * @param generatedTokens - Token ids produced so far (duplicates allowed).
 * @param penalty - Penalty factor applied to each previously seen token.
 * @returns A new array holding the penalized logits.
 */
export function applyRepPenalty(
  logits: Float32Array,
  generatedTokens: number[],
  penalty: number,
): Float32Array {
  const penalized = Float32Array.from(logits)
  const alreadyPenalized = new Set<number>()

  for (const tokenId of generatedTokens) {
    if (alreadyPenalized.has(tokenId)) {
      continue
    }
    alreadyPenalized.add(tokenId)

    const logit = penalized[tokenId]
    penalized[tokenId] = logit > 0 ? logit / penalty : logit * penalty
  }

  return penalized
}

/**
 * Min-p filtering: masks every token whose probability is less than `minP`
 * times the probability of the most likely token.
 *
 * All softmax probabilities share the same normalizer, so the test
 * `p_i < minP * p_max` reduces to `exp(logit_i - maxLogit) < minP`. No full
 * softmax (and no normalizing sum) is needed.
 *
 * @param logits - Raw logits; not mutated.
 * @param minP - Relative cutoff. Values <= 0 keep every token; values > 1
 *   mask every token, including the most likely one.
 * @returns A copy of `logits` where filtered entries are set to -1e9. Entries
 *   that were already lower than -1e9 (e.g. `-Infinity`) are left as-is, so
 *   masking never raises a logit.
 */
export function applyMinP(logits: Float32Array, minP: number): Float32Array {
  const maskedLogit = -1e9
  const next = new Float32Array(logits)

  let maxLogit = -Infinity
  for (const value of next) {
    if (value > maxLogit) {
      maxLogit = value
    }
  }

  for (const [index, value] of next.entries()) {
    // exp(value - maxLogit) is p_i / p_max; the softmax normalizer cancels out.
    if (Math.exp(value - maxLogit) < minP) {
      // Never lift an already-lower logit (such as -Infinity) up to the mask value.
      next[index] = Math.min(value, maskedLogit)
    }
  }

  return next
}

/**
 * Samples a token index from `softmax(logits / temperature)` using `Math.random()`.
 *
 * A temperature that is not strictly positive (0, negative, or NaN) is treated
 * as the greedy limit and returns the index of the largest logit (first one on
 * ties). Dividing by such a temperature would otherwise produce NaN/Infinity
 * weights.
 *
 * Tokens with zero weight (e.g. logits masked far below the maximum) are never
 * selected, even when `Math.random()` returns exactly 0.
 *
 * @param logits - Logits over the vocabulary; not mutated.
 * @param temperature - Softmax temperature; <= 0 means greedy decoding.
 * @returns The sampled token index, or -1 when `logits` is empty.
 */
export function sampleWithTemperature(
  logits: Float32Array,
  temperature: number,
): number {
  // No valid index exists for an empty vocabulary; keep the historical -1 result.
  if (logits.length === 0) {
    return -1
  }

  let maxLogit = -Infinity
  let argmax = 0
  for (const [index, value] of logits.entries()) {
    if (value > maxLogit) {
      maxLogit = value
      argmax = index
    }
  }

  // Greedy limit; `!(t > 0)` also routes NaN here instead of into the division.
  if (!(temperature > 0)) {
    return argmax
  }

  const weights = new Float64Array(logits.length)
  let total = 0
  let lastWeighted = logits.length - 1
  for (const [index, value] of logits.entries()) {
    // Subtracting the max before exponentiating keeps every weight in (0, 1].
    const weight = Math.exp((value - maxLogit) / temperature)
    weights[index] = weight
    total += weight
    if (weight > 0) {
      lastWeighted = index
    }
  }

  let remaining = Math.random() * total
  for (const [index, weight] of weights.entries()) {
    remaining -= weight
    // Strict comparison so a zero-weight token can never be picked.
    if (remaining < 0) {
      return index
    }
  }

  // Floating-point rounding can leave `remaining` at or just above zero;
  // fall back to the last token that actually carried weight.
  return lastWeighted
}