File size: 6,260 Bytes
357ae2c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 | import Foundation
/// Strategy used by `ConstrainedSegmentation.run` to place segment boundaries.
///
/// Port of the `algorithm` choice of `constrained_segmentation` in
/// `wtpsplit/utils/constraints.py`. Each raw value is spelled exactly like its case name.
public enum SegmentationAlgorithm: String, Sendable, CaseIterable {
    /// Dynamic-programming search for the highest-scoring admissible segmentation.
    case viterbi = "viterbi"
    /// Left-to-right search that commits to the best next boundary at each step.
    case greedy = "greedy"
}
/// Length-constrained segmentation by dynamic programming.
///
/// Maximizes `Σ log prior(segmentLength) + Σ log p(boundary)` subject to
/// `minLength ≤ segmentLength ≤ maxLength`. Returns boundary positions as 1-based
/// end indices in `[1, n)` (the terminal boundary `n` is excluded); callers form
/// `cuts = [0] + result + [n]` and slice.
enum ConstrainedSegmentation {
    /// Splits `probs.count` positions into segments whose lengths respect the bounds.
    ///
    /// - Parameters:
    ///   - probs: `probs[i]` is the probability of a boundary right after position `i`.
    ///     The last entry is never read, because the terminal boundary is implicit.
    ///   - prior: Weight of a segment of the given length. The Viterbi search treats a
    ///     non-positive (or NaN) weight as "length not allowed"; the greedy search
    ///     multiplies it into its score.
    ///   - minLength: Shortest allowed segment. Values below 1 are treated as 1.
    ///   - maxLength: Longest allowed segment. Values above `probs.count` are treated as
    ///     `probs.count`, so `Int.max` may be passed to mean "unbounded".
    ///   - algorithm: `.viterbi` for the dynamic-programming optimum, `.greedy` for the
    ///     left-to-right approximation.
    /// - Returns: Interior cut positions in ascending order. Empty when `probs` is empty
    ///   or when no length satisfies `minLength ≤ length ≤ maxLength`.
    static func run(probs: [Double],
                    prior: (Int) -> Double,
                    minLength: Int,
                    maxLength: Int,
                    algorithm: SegmentationAlgorithm) -> [Int] {
        let n = probs.count
        // Normalize the bounds once so the helpers never see:
        //  - a minimum below 1 (greedy would index probs[-1], and a zero-length
        //    segment cannot advance either search);
        //  - a maximum <= 0 (the fallback chunker would never advance);
        //  - a maximum large enough to overflow `current + maxLength + 1`.
        // Clamping the maximum to n does not change any result, since no segment can
        // be longer than n.
        let shortest = max(minLength, 1)
        let longest = min(max(maxLength, 0), n)
        // Nothing to split, or no admissible segment length: the whole input stays
        // one segment.
        guard n > 0, shortest <= longest else { return [] }
        switch algorithm {
        case .greedy:
            return greedy(probs: probs, prior: prior, n: n, minLength: shortest, maxLength: longest)
        case .viterbi:
            return viterbi(probs: probs, prior: prior, n: n, minLength: shortest, maxLength: longest)
        }
    }

    // MARK: - Viterbi

    /// Dynamic-programming search over segmentations with `minLength ≤ length ≤ maxLength`.
    ///
    /// `dp[i]` is the best log score for the first `i` positions, and `back[i]` is the
    /// start of its final segment. Expects `1 ≤ minLength ≤ maxLength ≤ n`, as normalized
    /// by `run`. Falls back to fixed-size chunking when `dp[n]` stays `-inf`.
    private static func viterbi(probs: [Double], prior: (Int) -> Double,
                                n: Int, minLength: Int, maxLength: Int) -> [Int] {
        let negInf = -Double.infinity

        // `prior` only receives the segment length, so each admissible length is scored
        // once instead of once per (end, start) pair. Non-positive or NaN weights map to
        // -inf and are skipped below.
        var logPrior = [Double](repeating: negInf, count: n + 1)
        let longest = min(maxLength, n)
        if minLength <= longest {
            for length in minLength...longest {
                let weight = prior(length)
                if weight > 0 { logPrior[length] = log(weight) }
            }
        }

        var dp = [Double](repeating: negInf, count: n + 1)
        var back = [Int](repeating: 0, count: n + 1)
        dp[0] = 0.0
        for current in 1...n {
            let earliest = max(0, current - maxLength)   // segment ≤ maxLength
            let latest = current - minLength             // segment ≥ minLength
            if latest < earliest { continue }
            // probs[i] is the boundary after position i, so a cut at `current` scores
            // log(probs[current - 1]). The terminal boundary at n is implicit and unscored.
            let boundary = current < n ? log(probs[current - 1]) : 0.0
            for segStart in earliest...latest {
                let lengthScore = logPrior[current - segStart]
                if lengthScore == negInf || dp[segStart] == negInf { continue }
                let candidate = dp[segStart] + lengthScore + boundary
                if candidate > dp[current] {
                    dp[current] = candidate
                    back[current] = segStart
                }
            }
        }
        if dp[n] == negInf {
            return fallbackGreedy(n: n, minLength: minLength, maxLength: maxLength)
        }
        // Follow the back-pointers from the end. back[n] is the last interior cut
        // (or 0). Every segment has length ≥ minLength ≥ 1, so `pos` strictly decreases.
        var indices = [Int]()
        var pos = back[n]
        while pos > 0 {
            indices.append(pos)
            pos = back[pos]
        }
        indices.reverse()
        return handleShortFinalSegment(indices, n: n, minLength: minLength, maxLength: maxLength)
    }

    // MARK: - Greedy

    /// Left-to-right approximation that picks the best-scoring end for each segment.
    ///
    /// Candidate ends lie in `[current + minLength, current + maxLength]`, clipped to
    /// `n`. The score is `probs[end - 1] * prior(length)`, or `prior(length)` alone
    /// for the terminal segment. On reaching `n` with a tail shorter than `minLength`,
    /// the last cut is dropped if the merged segment stays within `maxLength`.
    /// Expects `1 ≤ minLength ≤ maxLength ≤ n`, as normalized by `run`.
    private static func greedy(probs: [Double], prior: (Int) -> Double,
                               n: Int, minLength: Int, maxLength: Int) -> [Int] {
        var indices = [Int]()
        var current = 0
        while current < n {
            let remaining = n - current
            let firstEnd = current + minLength
            let stopEnd = min(current + maxLength + 1, n + 1)
            // With no admissible end left (tail shorter than minLength), finish at n.
            var bestEnd = n
            if firstEnd < stopEnd {
                var bestScore = -Double.infinity
                var chosen = -1
                for end in firstEnd..<stopEnd {
                    let lengthWeight = prior(end - current)
                    let score = end == n ? lengthWeight : probs[end - 1] * lengthWeight
                    if score > bestScore {
                        bestScore = score
                        chosen = end
                    }
                }
                // No score exceeded -inf (e.g. all NaN): take the longest admissible step.
                bestEnd = chosen == -1 ? min(current + maxLength, n) : chosen
            }
            if bestEnd == n {
                // A too-short tail is merged into the previous segment when the merged
                // segment still fits within maxLength.
                if remaining < minLength, !indices.isEmpty {
                    let newLast = indices.count >= 2 ? indices[indices.count - 2] : 0
                    if n - newLast <= maxLength { indices.removeLast() }
                }
                break
            }
            indices.append(bestEnd)
            current = bestEnd
        }
        return indices
    }

    // MARK: - Helpers (ports of the private functions in constraints.py)

    /// Fixed-size chunking used when the Viterbi search finds no finite-score solution.
    ///
    /// Steps forward by `maxLength` and records each step end that is at least
    /// `minLength` past the previous position. The result is then passed through
    /// `handleShortFinalSegment`. Returns `[]` when `maxLength` is smaller than
    /// `max(minLength, 1)`: no chunk could be recorded, and a non-positive step
    /// would never advance the loop.
    private static func fallbackGreedy(n: Int, minLength: Int, maxLength: Int) -> [Int] {
        guard maxLength >= max(minLength, 1) else { return [] }
        var indices = [Int]()
        var curr = 0
        while curr < n {
            let next = min(curr + maxLength, n)
            if next >= curr + minLength { indices.append(next) }
            curr = next
        }
        return handleShortFinalSegment(indices, n: n, minLength: minLength, maxLength: maxLength)
    }

    /// Repairs a final segment (last cut to `n`) that is shorter than `minLength`.
    ///
    /// With at least two cuts:
    /// - the last cut is dropped when merging the tail into the previous segment keeps
    ///   it within `maxLength`;
    /// - otherwise the last cut is moved to `max(n - minLength, prevSplit + 1)`, if that
    ///   keeps the preceding segment within `maxLength`.
    ///
    /// With a single cut:
    /// - the result is `[]` when `n ≤ maxLength`;
    /// - otherwise the cut is moved to `n - minLength` when that value is at least
    ///   `minLength`.
    ///
    /// Empty input, or input whose final segment is already long enough, is returned
    /// unchanged.
    private static func handleShortFinalSegment(_ input: [Int], n: Int,
                                                minLength: Int, maxLength: Int) -> [Int] {
        var indices = input
        guard let last = indices.last else { return indices }
        let lastChunkLen = n - last
        if lastChunkLen >= minLength { return indices }
        if indices.count > 1 {
            let prevSplit = indices[indices.count - 2]
            if n - prevSplit <= maxLength {
                indices.removeLast()
            } else {
                let desired = n - minLength
                let minValid = prevSplit + 1
                let adjusted = max(desired, minValid)
                if adjusted - prevSplit <= maxLength {
                    indices[indices.count - 1] = adjusted
                }
            }
        } else {
            if n <= maxLength {
                return []
            } else {
                let desired = n - minLength
                if desired >= minLength {
                    indices[indices.count - 1] = desired
                }
            }
        }
        return indices
    }
}
|