// wtpsplit-kit/Sources/WtpsplitKit/Segmentation/ConstrainedSegmentation.swift
// Uploaded by krmanik via huggingface_hub (commit 357ae2c).
import Foundation
/// Search strategy for `ConstrainedSegmentation.run`.
///
/// Ported from `constrained_segmentation` in `wtpsplit/utils/constraints.py`.
public enum SegmentationAlgorithm: String, Sendable, CaseIterable {
    /// Dynamic programming over every admissible segmentation.
    case viterbi
    /// Left-to-right search that picks the best-scoring end inside each admissible window.
    case greedy
}

/// Length-constrained segmentation of a sequence of boundary probabilities.
///
/// `probs[i]` is the probability of a boundary directly after element `i`.
/// - `.viterbi` maximizes `Σ log prior(segmentLength) + Σ log p(boundary)` subject to
///   `minLength ≤ segmentLength ≤ maxLength`. It falls back to fixed-stride cuts when no
///   segmentation has a finite score.
/// - `.greedy` walks left to right. From each cut it picks the end maximizing
///   `p(boundary) · prior(length)`.
///
/// Returns boundary positions as 1-based end indices in `[1, n)` (the terminal boundary
/// `n` is excluded); callers form `cuts = [0] + result + [n]` and slice.
enum ConstrainedSegmentation {
    /// Segments `probs` into runs whose lengths satisfy `minLength ≤ length ≤ maxLength`.
    ///
    /// - Parameters:
    ///   - probs: One boundary probability per element; `probs[i]` scores a cut right after
    ///     element `i`.
    ///   - prior: Weight of a segment of the given length. It is evaluated once per length in
    ///     `minLength...maxLength` (after the adjustments below). Viterbi treats a weight
    ///     `≤ 0` as forbidding that length.
    ///   - minLength: Minimum segment length; values below 1 are treated as 1.
    ///   - maxLength: Maximum segment length; values above `probs.count` are treated as
    ///     `probs.count`, so a sentinel such as `Int.max` is safe.
    ///   - algorithm: Search strategy.
    /// - Returns: Ascending cut positions in `[1, probs.count)`. The result is empty when
    ///   `probs` is empty or when `maxLength < minLength` after the adjustments above.
    static func run(probs: [Double],
                    prior: (Int) -> Double,
                    minLength: Int,
                    maxLength: Int,
                    algorithm: SegmentationAlgorithm) -> [Int] {
        let n = probs.count
        guard n > 0 else { return [] }
        // A segment cannot be empty. Non-positive minimum lengths used to make greedy read
        // probs[-1] and could make the Viterbi backtrack point a position at itself.
        let minLen = max(1, minLength)
        // No segment can be longer than n, so capping changes no result. It also keeps the
        // window arithmetic (`current + maxLength`) from overflowing for sentinels like Int.max.
        let maxLen = min(maxLength, n)
        // Unsatisfiable bounds produce no cuts. Both searches already returned [] for such
        // bounds when maxLength >= 1; with maxLength <= 0 the Viterbi fallback never advanced.
        guard minLen <= maxLen else { return [] }
        // The prior depends only on the segment length, so evaluate it once per admissible
        // length instead of once per (start, end) pair.
        let priorByLength: [Double] = (0...maxLen).map { $0 < minLen ? 0 : prior($0) }
        switch algorithm {
        case .greedy:
            return greedy(probs: probs, priorByLength: priorByLength,
                          n: n, minLength: minLen, maxLength: maxLen)
        case .viterbi:
            return viterbi(probs: probs, priorByLength: priorByLength,
                           n: n, minLength: minLen, maxLength: maxLen)
        }
    }

    // MARK: - Viterbi

    /// Exact dynamic program over all admissible segmentations.
    ///
    /// `dp[i]` is the best score of a segmentation of the first `i` elements that ends with a
    /// cut at `i`. `back[i]` is the start of that last segment.
    /// Expects `1 ≤ minLength ≤ maxLength ≤ n` and `priorByLength.count == maxLength + 1`.
    private static func viterbi(probs: [Double], priorByLength: [Double],
                                n: Int, minLength: Int, maxLength: Int) -> [Int] {
        let negInf = -Double.infinity
        // log prior per length; nil where the prior forbids that length (weight <= 0).
        let logPrior: [Double?] = priorByLength.map { weight -> Double? in
            weight > 0 ? log(weight) : nil
        }
        var dp = [Double](repeating: negInf, count: n + 1)
        var back = [Int](repeating: 0, count: n + 1)
        dp[0] = 0
        for current in 1...n {
            let earliest = max(0, current - maxLength)   // segment <= maxLength
            let latest = current - minLength             // segment >= minLength
            guard earliest <= latest else { continue }
            // probs[i] scores the boundary after element i, so a cut at `current` adds
            // log(probs[current - 1]). The terminal position n is always a boundary and adds nothing.
            let boundaryScore = current < n ? log(probs[current - 1]) : 0
            for segStart in earliest...latest where dp[segStart] > negInf {
                guard let lengthScore = logPrior[current - segStart] else { continue }
                let candidate = dp[segStart] + lengthScore + boundaryScore
                if candidate > dp[current] {
                    dp[current] = candidate
                    back[current] = segStart
                }
            }
        }
        // No admissible segmentation has a finite score (e.g. the prior is 0 for every length).
        guard dp[n] > negInf else {
            return fallbackGreedy(n: n, minLength: minLength, maxLength: maxLength)
        }
        // Follow back-pointers from n. Position n is the terminal boundary and is not reported.
        var cuts = [Int]()
        var pos = back[n]
        while pos > 0 {
            cuts.append(pos)
            pos = back[pos]
        }
        cuts.reverse()
        return handleShortFinalSegment(cuts, n: n, minLength: minLength, maxLength: maxLength)
    }

    // MARK: - Greedy

    /// Left-to-right heuristic.
    ///
    /// From each cut it chooses the end, within the admissible window, with the highest
    /// `p(boundary) · prior(length)`; the terminal end `n` scores just `prior(length)`.
    /// A final segment shorter than `minLength` is merged into its predecessor when the merged
    /// segment fits in `maxLength`.
    /// Expects `1 ≤ minLength ≤ maxLength ≤ n` and `priorByLength.count == maxLength + 1`.
    private static func greedy(probs: [Double], priorByLength: [Double],
                               n: Int, minLength: Int, maxLength: Int) -> [Int] {
        var cuts = [Int]()
        var current = 0
        while current < n {
            // Admissible ends run from current + minLength to current + maxLength, capped at n.
            let firstEnd = current + minLength
            let lastEnd = min(current + maxLength, n)
            // Empty window means fewer than minLength elements remain: close the final segment.
            var bestEnd = n
            if firstEnd <= lastEnd {
                var bestScore = -Double.infinity
                var chosen: Int?
                for end in firstEnd...lastEnd {
                    let lengthPrior = priorByLength[end - current]
                    let score = end == n ? lengthPrior : probs[end - 1] * lengthPrior
                    if score > bestScore {
                        bestScore = score
                        chosen = end
                    }
                }
                // Every score was NaN or -∞: take the longest admissible segment.
                bestEnd = chosen ?? lastEnd
            }
            if bestEnd == n {
                // A too-short tail is merged into the previous segment when the result fits.
                if n - current < minLength && !cuts.isEmpty {
                    let previous = cuts.count >= 2 ? cuts[cuts.count - 2] : 0
                    if n - previous <= maxLength { cuts.removeLast() }
                }
                break
            }
            cuts.append(bestEnd)
            current = bestEnd
        }
        return cuts
    }

    // MARK: - Helpers (ports of the private functions in constraints.py)

    /// Fixed-stride segmentation used when Viterbi finds no finite-score solution.
    ///
    /// Cuts every `maxLength` elements. A cut is not recorded when the piece it closes is
    /// shorter than `minLength`; only the last piece can be. The recorded cuts may end with `n`.
    /// The tail is then repaired by `handleShortFinalSegment`.
    /// Requires `maxLength ≥ 1` so the loop advances (guaranteed by `run`).
    private static func fallbackGreedy(n: Int, minLength: Int, maxLength: Int) -> [Int] {
        var cuts = [Int]()
        var curr = 0
        while curr < n {
            let next = min(curr + maxLength, n)
            if next >= curr + minLength { cuts.append(next) }
            curr = next
        }
        return handleShortFinalSegment(cuts, n: n, minLength: minLength, maxLength: maxLength)
    }

    /// Repairs a final segment `[cuts.last, n)` that is shorter than `minLength`.
    ///
    /// With two or more cuts:
    /// - The tail is merged into the previous segment (the last cut is dropped) when the merged
    ///   segment fits in `maxLength`.
    /// - Otherwise the last cut moves left to `max(n - minLength, previous + 1)`, if the
    ///   segment before it still fits in `maxLength`.
    ///
    /// With a single cut:
    /// - All cuts are dropped when `n ≤ maxLength`.
    /// - Otherwise the cut moves to `n - minLength`, when that is at least `minLength`.
    ///
    /// A trailing cut at `n` (as produced by the fallback) forms a zero-length tail and is
    /// removed through the same path. Input without a short tail is returned unchanged.
    private static func handleShortFinalSegment(_ input: [Int], n: Int,
                                                minLength: Int, maxLength: Int) -> [Int] {
        var indices = input
        guard let last = indices.last else { return indices }
        let lastChunkLen = n - last
        if lastChunkLen >= minLength { return indices }
        if indices.count > 1 {
            let prevSplit = indices[indices.count - 2]
            if n - prevSplit <= maxLength {
                indices.removeLast()
            } else {
                let desired = n - minLength
                let minValid = prevSplit + 1
                let adjusted = max(desired, minValid)
                if adjusted - prevSplit <= maxLength {
                    indices[indices.count - 1] = adjusted
                }
            }
        } else {
            if n <= maxLength {
                return []
            } else {
                let desired = n - minLength
                if desired >= minLength {
                    indices[indices.count - 1] = desired
                }
            }
        }
        return indices
    }
}