| import Foundation |
|
|
| |
| |
/// Strategy used to choose segment boundaries.
///
/// The raw value of each case is its own name, so the type can be round-tripped
/// through `rawValue` / `init(rawValue:)`.
public enum SegmentationAlgorithm: String, Sendable, CaseIterable {
    /// Dynamic-programming search over all boundary placements.
    case viterbi = "viterbi"
    /// Left-to-right search that commits to one boundary at a time.
    case greedy = "greedy"
}
|
|
| |
| |
| |
| |
| |
| |
/// Splits a sequence into consecutive segments whose lengths are constrained to
/// `minLength...maxLength`, guided by per-position boundary scores and a prior
/// over segment lengths.
///
/// Results are interior split positions in ascending order. A split position `k`
/// means one segment ends after element `k - 1` and the next starts at element `k`.
/// The end of the input (`probs.count`) is never reported as a split.
enum ConstrainedSegmentation {

    /// Computes split positions for `probs` with the requested algorithm.
    ///
    /// - Parameters:
    ///   - probs: Boundary scores; `probs[k - 1]` scores a split at position `k`.
    ///     Viterbi takes their logarithm, greedy multiplies them with the prior.
    ///   - prior: Weight of a segment of the given length. Viterbi skips lengths
    ///     whose weight is not strictly positive.
    ///   - minLength: Smallest allowed segment length. Values below 1 are treated as 1.
    ///   - maxLength: Largest allowed segment length. Values above `probs.count`
    ///     are treated as `probs.count`, so `Int.max` may be used for "unbounded".
    ///   - algorithm: Search strategy.
    /// - Returns: Ascending split positions, or `[]` when `probs` is empty or no
    ///   segment length satisfies both bounds.
    static func run(probs: [Double],
                    prior: (Int) -> Double,
                    minLength: Int,
                    maxLength: Int,
                    algorithm: SegmentationAlgorithm) -> [Int] {
        let n = probs.count
        guard n > 0 else { return [] }

        // Normalise the bounds once so the helpers can rely on
        // 1 <= minLength <= maxLength <= n. A zero or negative minimum would allow
        // empty segments (negative indexing, non-terminating loops); a maximum above
        // n adds nothing and can overflow when added to a position.
        let minLen = max(1, minLength)
        let maxLen = min(maxLength, n)
        guard minLen <= maxLen else { return [] }

        switch algorithm {
        case .greedy:
            return greedy(probs: probs, prior: prior, n: n, minLength: minLen, maxLength: maxLen)
        case .viterbi:
            return viterbi(probs: probs, prior: prior, n: n, minLength: minLen, maxLength: maxLen)
        }
    }

    /// Finds the segmentation with the highest total log-score by dynamic programming.
    ///
    /// Every segment, including the last, must have a length in
    /// `minLength...maxLength`. When no such segmentation has a finite score the
    /// result of `fallbackGreedy` is returned instead.
    ///
    /// Expects `1 <= minLength`; `run` guarantees this.
    private static func viterbi(probs: [Double], prior: (Int) -> Double,
                                n: Int, minLength: Int, maxLength: Int) -> [Int] {
        let unreachable = -Double.infinity
        // best[k]: highest log-score of a valid segmentation of the first k elements.
        var best = [Double](repeating: unreachable, count: n + 1)
        // previous[k]: start position of the last segment in that segmentation.
        var previous = [Int](repeating: 0, count: n + 1)
        best[0] = 0.0

        for end in 1...n {
            let earliest = max(0, end - maxLength)
            let latest = end - minLength
            guard earliest <= latest else { continue }

            // The boundary term does not depend on where the segment starts, so it
            // is computed once per end position. The end of the input carries no
            // boundary score.
            let boundaryLog = end < n ? log(probs[end - 1]) : 0.0

            for start in earliest...latest where best[start] != unreachable {
                let segmentPrior = prior(end - start)
                guard segmentPrior > 0 else { continue }

                let candidate = best[start] + log(segmentPrior) + boundaryLog
                if candidate > best[end] {
                    best[end] = candidate
                    previous[end] = start
                }
            }
        }

        guard best[n] != unreachable else {
            return fallbackGreedy(n: n, minLength: minLength, maxLength: maxLength)
        }

        // Walk the back-pointers from the end of the input. Starting at previous[n]
        // leaves n itself out of the result. The walk terminates because
        // previous[k] <= k - minLength < k.
        var indices = [Int]()
        var pos = previous[n]
        while pos > 0 {
            indices.append(pos)
            pos = previous[pos]
        }
        indices.reverse()
        return handleShortFinalSegment(indices, n: n, minLength: minLength, maxLength: maxLength)
    }

    /// Chooses boundaries left to right, each time taking the end position in
    /// `current + minLength ... current + maxLength` with the highest
    /// `probs[end - 1] * prior(length)` (just `prior(length)` when `end == n`).
    /// Ties go to the earliest position.
    ///
    /// When the remaining tail is shorter than `minLength`, the last boundary is
    /// dropped if the merged final segment still fits in `maxLength`; otherwise the
    /// short tail is left in place.
    ///
    /// Expects `1 <= minLength` and `maxLength <= n`; `run` guarantees both.
    private static func greedy(probs: [Double], prior: (Int) -> Double,
                               n: Int, minLength: Int, maxLength: Int) -> [Int] {
        var indices = [Int]()
        var current = 0

        while current < n {
            let firstEnd = current + minLength
            let lastEnd = min(current + maxLength, n)

            guard firstEnd <= lastEnd else {
                // No admissible end position is left: the rest becomes the final segment.
                if n - current < minLength, !indices.isEmpty {
                    let previousBoundary = indices.dropLast().last ?? 0
                    if n - previousBoundary <= maxLength { indices.removeLast() }
                }
                break
            }

            // Defaults to the farthest admissible end when no score beats -infinity
            // (for example when every score is NaN).
            var bestEnd = lastEnd
            var bestScore = -Double.infinity
            for end in firstEnd...lastEnd {
                let lengthWeight = prior(end - current)
                let score = end == n ? lengthWeight : probs[end - 1] * lengthWeight
                if score > bestScore {
                    bestScore = score
                    bestEnd = end
                }
            }

            // Ending at n closes the final segment; n itself is not a split.
            if bestEnd == n { break }

            indices.append(bestEnd)
            current = bestEnd
        }
        return indices
    }

    /// Score-free segmentation used when Viterbi finds no valid path: cut every
    /// `maxLength` elements, then repair a too-short final segment.
    private static func fallbackGreedy(n: Int, minLength: Int, maxLength: Int) -> [Int] {
        // A non-positive step would never advance the cursor.
        guard maxLength > 0 else { return [] }

        var indices = [Int]()
        var curr = 0
        while curr < n {
            // Clamp the step rather than the sum so a very large maxLength cannot overflow.
            let next = curr + min(maxLength, n - curr)
            // The end of the input is not a split position, so it is never recorded.
            if next < n, next - curr >= minLength { indices.append(next) }
            curr = next
        }
        return handleShortFinalSegment(indices, n: n, minLength: minLength, maxLength: maxLength)
    }

    /// Repairs a final segment (`n - input.last`) that is shorter than `minLength`.
    ///
    /// - If the final two segments together fit in `maxLength`, the last split is
    ///   removed so they merge.
    /// - Otherwise the last split is moved towards the start so the final segment
    ///   has length `minLength`, provided the segment before it stays within
    ///   `maxLength` (or, with a single split, provided the first segment keeps at
    ///   least `minLength` elements).
    /// - If neither applies, the input is returned unchanged.
    private static func handleShortFinalSegment(_ input: [Int], n: Int,
                                                minLength: Int, maxLength: Int) -> [Int] {
        guard let last = input.last, n - last < minLength else { return input }
        var indices = input

        guard indices.count > 1 else {
            // A single split: merging would make the whole input one segment.
            if n <= maxLength { return [] }
            let desired = n - minLength
            if desired >= minLength { indices[0] = desired }
            return indices
        }

        let prevSplit = indices[indices.count - 2]
        if n - prevSplit <= maxLength {
            indices.removeLast()
        } else {
            // Keep the moved split strictly after the previous one.
            let adjusted = max(n - minLength, prevSplit + 1)
            if adjusted - prevSplit <= maxLength {
                indices[indices.count - 1] = adjusted
            }
        }
        return indices
    }
}
|
|