import Foundation

/// Segmentation algorithm. Port of `constrained_segmentation` in
/// `wtpsplit/utils/constraints.py`.
public enum SegmentationAlgorithm: String, Sendable, CaseIterable {
    /// Exact dynamic-programming search over all valid segmentations.
    case viterbi
    /// Left-to-right search that commits to the best next boundary at each step.
    case greedy
}

/// Length-constrained segmentation.
///
/// The Viterbi variant maximizes `Σ log prior(segmentLength) + Σ log p(boundary)` subject to
/// `minLength ≤ segmentLength ≤ maxLength`. Returns boundary positions as 1-based
/// end indices in `[1, n)` (the terminal boundary `n` is excluded); callers form
/// `cuts = [0] + result + [n]` and slice.
enum ConstrainedSegmentation {
    /// Segments a sequence of `probs.count` units.
    ///
    /// - Parameters:
    ///   - probs: `probs[i]` is the probability of a boundary after unit `i`.
    ///   - prior: Weight of a segment of the given length; lengths with a weight `<= 0` are never chosen
    ///     by the scored search.
    ///   - minLength: Minimum segment length. Values below 1 are treated as 1, because a zero-length
    ///     segment would stop the search from advancing.
    ///   - maxLength: Maximum segment length. Values above `probs.count` (for example `Int.max` meaning
    ///     "no limit") are clamped to `probs.count`, which cannot change the result but avoids
    ///     arithmetic overflow.
    ///   - algorithm: Search strategy.
    /// - Returns: Interior boundaries in increasing order, or `[]` when the input is empty or the
    ///   constraints leave no interior boundary to place (including `maxLength < minLength`).
    static func run(probs: [Double], prior: (Int) -> Double, minLength: Int, maxLength: Int,
                    algorithm: SegmentationAlgorithm) -> [Int] {
        let n = probs.count
        guard n > 0 else { return [] }

        let minLen = max(1, minLength)
        let maxLen = min(maxLength, n)
        // With no admissible segment length both algorithms end up with zero boundaries;
        // returning here also keeps the helpers from looping on a non-positive maxLength.
        guard maxLen >= minLen else { return [] }

        switch algorithm {
        case .greedy:
            return greedy(probs: probs, prior: prior, n: n, minLength: minLen, maxLength: maxLen)
        case .viterbi:
            return viterbi(probs: probs, prior: prior, n: n, minLength: minLen, maxLength: maxLen)
        }
    }

    // MARK: - Viterbi

    /// Exact search: `dp[i]` is the best score of any valid segmentation of the first `i` units and
    /// `back[i]` is the start of the last segment in that segmentation.
    ///
    /// Expects `1 <= minLength <= maxLength <= n` (established by `run`).
    private static func viterbi(probs: [Double], prior: (Int) -> Double, n: Int,
                                minLength: Int, maxLength: Int) -> [Int] {
        let negInf = -Double.infinity
        var dp = [Double](repeating: negInf, count: n + 1)
        var back = [Int](repeating: 0, count: n + 1)
        dp[0] = 0.0

        for current in 1...n {
            let earliest = max(0, current - maxLength)  // segment ≤ maxLength
            let latest = current - minLength            // segment ≥ minLength
            if latest < earliest { continue }

            // probs[i] = boundary after unit i → the boundary at `current` is probs[current - 1].
            // The terminal position `n` always ends a segment, so it contributes no boundary term.
            let boundaryTerm = current < n ? log(probs[current - 1]) : 0.0

            for segStart in earliest...latest {
                let priorProb = prior(current - segStart)
                if priorProb <= 0 { continue }
                if dp[segStart] == negInf { continue }
                let candidate = dp[segStart] + log(priorProb) + boundaryTerm
                if candidate > dp[current] {
                    dp[current] = candidate
                    back[current] = segStart
                }
            }
        }

        // No valid segmentation reaches `n`: fall back to fixed-size chunks.
        if dp[n] == negInf {
            return fallbackGreedy(n: n, minLength: minLength, maxLength: maxLength)
        }

        // Walk the back-pointers from `n`, collecting interior boundaries only.
        var indices = [Int]()
        var pos = back[n]
        while pos > 0 {
            indices.append(pos)
            pos = back[pos]
        }
        indices.reverse()
        return handleShortFinalSegment(indices, n: n, minLength: minLength, maxLength: maxLength)
    }

    // MARK: - Greedy

    /// Left-to-right search: from each segment start, picks the best-scoring end within
    /// `[current + minLength, current + maxLength]` and commits to it.
    ///
    /// Expects `1 <= minLength <= maxLength <= n` (established by `run`).
    private static func greedy(probs: [Double], prior: (Int) -> Double, n: Int,
                               minLength: Int, maxLength: Int) -> [Int] {
        var indices = [Int]()
        var current = 0

        while current < n {
            let startSearch = current + minLength
            let endSearch = min(current + maxLength + 1, n + 1)
            var bestEnd = n  // used as-is when fewer than `minLength` units remain

            if startSearch < endSearch {
                var bestScore = -Double.infinity
                bestEnd = -1
                // NOTE(review): the body of this loop was damaged in the file as received (the text
                // between `startSearch..` and `> bestScore` was missing). It is reconstructed to score
                // candidates the same way as the Viterbi pass — confirm against constraints.py.
                for end in startSearch..<endSearch {
                    let priorProb = prior(end - current)
                    if priorProb <= 0 { continue }
                    var score = log(priorProb)
                    if end < n {
                        score += log(probs[end - 1])
                    }
                    if score > bestScore {
                        bestScore = score
                        bestEnd = end
                    }
                }
                // Every candidate was rejected: take the longest allowed segment.
                if bestEnd == -1 {
                    bestEnd = min(current + maxLength, n)
                }
            }

            if bestEnd == n {
                // Reached the end. If the tail is too short, drop the last boundary so the tail merges
                // into the previous segment — provided the merged segment still fits in maxLength.
                if n - current < minLength, !indices.isEmpty {
                    let newLast = indices.count >= 2 ? indices[indices.count - 2] : 0
                    if n - newLast <= maxLength {
                        indices.removeLast()
                    }
                }
                return indices
            }

            indices.append(bestEnd)
            current = bestEnd
        }
        return indices
    }

    // MARK: - Helpers (ports of the private functions in constraints.py)

    /// Score-free fallback: cuts every `maxLength` units, skipping a cut that would create a segment
    /// shorter than `minLength`, then repairs a short final segment.
    ///
    /// Expects `maxLength >= 1` (established by `run`); otherwise the loop would not advance.
    private static func fallbackGreedy(n: Int, minLength: Int, maxLength: Int) -> [Int] {
        var indices = [Int]()
        var curr = 0
        while curr < n {
            let next = min(curr + maxLength, n)
            if next >= curr + minLength {
                indices.append(next)
            }
            curr = next
        }
        return handleShortFinalSegment(indices, n: n, minLength: minLength, maxLength: maxLength)
    }

    /// Repairs a final segment (from the last boundary to `n`) that is shorter than `minLength`.
    ///
    /// Tries, in order: removing the last boundary when the merged segment fits in `maxLength`;
    /// otherwise moving the last boundary left so the tail is `minLength` long, when that keeps the
    /// preceding segment valid. If neither applies the input is returned unchanged.
    private static func handleShortFinalSegment(_ input: [Int], n: Int,
                                                minLength: Int, maxLength: Int) -> [Int] {
        guard let last = input.last, n - last < minLength else { return input }

        var indices = input
        if indices.count > 1 {
            let prevSplit = indices[indices.count - 2]
            if n - prevSplit <= maxLength {
                indices.removeLast()
            } else {
                // Shift the boundary left, but keep it strictly after the previous one.
                let adjusted = max(n - minLength, prevSplit + 1)
                if adjusted - prevSplit <= maxLength {
                    indices[indices.count - 1] = adjusted
                }
            }
        } else if n <= maxLength {
            // A single boundary whose removal leaves one segment that still fits.
            return []
        } else {
            let desired = n - minLength
            if desired >= minLength {
                indices[0] = desired
            }
        }
        return indices
    }
}