File size: 6,260 Bytes
357ae2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import Foundation

/// Strategy used by `ConstrainedSegmentation.run` to place segment boundaries.
///
/// Ported from the `algorithm` choice of `constrained_segmentation` in
/// `wtpsplit/utils/constraints.py`. Raw values equal the case names
/// (`"viterbi"`, `"greedy"`), so the type round-trips through strings.
///
/// - `viterbi`: dynamic-programming search over all admissible segmentations.
/// - `greedy`: left-to-right, taking the best-scoring cut at each step.
public enum SegmentationAlgorithm: String, Sendable, CaseIterable {
    case viterbi, greedy
}

/// Length-constrained segmentation by dynamic programming.
///
/// Maximizes `Σ log prior(segmentLength) + Σ log p(boundary)` subject to
/// `minLength ≤ segmentLength ≤ maxLength`. Returns boundary positions as 1-based
/// end indices in `[1, n)` (the terminal boundary `n` is excluded); callers form
/// `cuts = [0] + result + [n]` and slice.
enum ConstrainedSegmentation {

    /// Chooses segment boundaries for `probs.count` positions.
    ///
    /// - Parameters:
    ///   - probs: Boundary probabilities; `probs[i]` scores a cut right after
    ///     position `i`, so a cut at end index `k` uses `probs[k - 1]`.
    ///   - prior: Weight of a segment of a given length. `.viterbi` never uses a
    ///     length whose prior is `<= 0`; `.greedy` multiplies it into its score.
    ///   - minLength: Shortest allowed segment. Values below 1 are treated as 1,
    ///     since a zero-length segment cannot exist.
    ///   - maxLength: Longest allowed segment. Values above `probs.count` are
    ///     treated as `probs.count`; no segment can be longer, so this does not
    ///     change the result, but it keeps index arithmetic from overflowing
    ///     (e.g. when `Int.max` is passed to mean "unbounded").
    ///   - algorithm: Which search to run.
    /// - Returns: Ascending interior cut positions in `[1, probs.count)`. Empty
    ///   when `probs` is empty or when no segment length satisfies the bounds
    ///   (`maxLength < minLength`, or fewer than `minLength` positions).
    static func run(probs: [Double],
                    prior: (Int) -> Double,
                    minLength: Int,
                    maxLength: Int,
                    algorithm: SegmentationAlgorithm) -> [Int] {
        let n = probs.count
        if n == 0 { return [] }

        // Normalize once so both strategies can rely on 1 <= minLen <= maxLen <= n.
        // Without this, minLength <= 0 indexes probs[-1] in greedy (and can create a
        // self back-pointer in Viterbi), maxLength <= 0 makes fallbackGreedy loop
        // forever, and maxLength == Int.max overflows `current + maxLength + 1`.
        let minLen = max(1, minLength)
        let maxLen = min(maxLength, n)

        // Unsatisfiable bounds: keep the input as a single segment. This is what
        // both strategies already returned for such inputs whenever they terminated.
        if maxLen < minLen { return [] }

        switch algorithm {
        case .greedy:
            return greedy(probs: probs, prior: prior, n: n, minLength: minLen, maxLength: maxLen)
        case .viterbi:
            return viterbi(probs: probs, prior: prior, n: n, minLength: minLen, maxLength: maxLen)
        }
    }

    // MARK: - Viterbi

    /// Exact search over segmentations whose segment lengths lie in
    /// `minLength...maxLength`, scoring `Σ log prior(len) + Σ log probs[cut - 1]`
    /// (the terminal cut at `n` contributes no boundary term). Falls back to
    /// fixed-size chunking when no segmentation has a finite score.
    ///
    /// Expects `1 <= minLength <= maxLength <= n` (established by `run`).
    private static func viterbi(probs: [Double], prior: (Int) -> Double,
                                n: Int, minLength: Int, maxLength: Int) -> [Int] {
        let negInf = -Double.infinity

        // log prior per segment length, computed once instead of once per
        // (end, start) pair. `nil` marks lengths whose prior is not positive, which
        // are never usable. NOTE(review): assumes `prior` is a pure function of length.
        var logPrior = [Double?](repeating: nil, count: maxLength + 1)
        for length in minLength...maxLength {
            let p = prior(length)
            if p > 0 { logPrior[length] = log(p) }
        }

        // dp[i]: best score for segmenting the first i positions.
        // back[i]: start of the final segment on that best path.
        var dp = [Double](repeating: negInf, count: n + 1)
        var back = [Int](repeating: 0, count: n + 1)
        dp[0] = 0.0

        for current in 1...n {
            let earliest = max(0, current - maxLength)   // segment ≤ maxLength
            let latest = current - minLength             // segment ≥ minLength
            if latest < earliest { continue }

            // probs[i] = boundary after position i → a cut at `current` is scored by
            // probs[current - 1]. The terminal cut at n is free. Invariant over segStart.
            let boundaryScore = current < n ? log(probs[current - 1]) : 0.0

            for segStart in earliest...latest {
                guard let segPrior = logPrior[current - segStart],
                      dp[segStart] != negInf else { continue }
                let candidate = dp[segStart] + segPrior + boundaryScore
                if candidate > dp[current] {
                    dp[current] = candidate
                    back[current] = segStart
                }
            }
        }

        if dp[n] == negInf {
            return fallbackGreedy(n: n, minLength: minLength, maxLength: maxLength)
        }

        // Follow back-pointers from n to 0. Each step moves left by at least
        // minLength >= 1, so the walk terminates.
        var indices = [Int]()
        var pos = n
        while pos > 0 {
            indices.append(pos)
            pos = back[pos]
        }
        indices.reverse()
        if indices.last == n { indices.removeLast() }   // drop the terminal boundary
        return handleShortFinalSegment(indices, n: n, minLength: minLength, maxLength: maxLength)
    }

    // MARK: - Greedy

    /// Left-to-right segmentation. From each start it picks the end in
    /// `start+minLength ... start+maxLength` maximizing
    /// `probs[end - 1] * prior(end - start)`, or just `prior(...)` when `end == n`.
    /// A tail shorter than `minLength` is merged into the previous segment when
    /// the merged segment still fits in `maxLength`.
    ///
    /// Expects `1 <= minLength <= maxLength <= n` (established by `run`).
    private static func greedy(probs: [Double], prior: (Int) -> Double,
                               n: Int, minLength: Int, maxLength: Int) -> [Int] {
        var indices = [Int]()
        var current = 0

        while current < n {
            let startSearch = current + minLength
            let endSearch = min(current + maxLength + 1, n + 1)

            var bestEnd = n
            if startSearch < endSearch {
                var bestScore = -Double.infinity
                bestEnd = -1
                for end in startSearch..<endSearch {
                    let score = (end == n) ? prior(end - current)
                                           : probs[end - 1] * prior(end - current)
                    if score > bestScore { bestScore = score; bestEnd = end }
                }
                // Every score was NaN or -inf: take the longest allowed step.
                if bestEnd == -1 { bestEnd = min(current + maxLength, n) }
            }
            // Otherwise no admissible end exists: fewer than minLength positions remain,
            // so the rest becomes the final segment (bestEnd == n).

            if bestEnd == n {
                // Merge a too-short tail into the previous segment if the result fits.
                if n - current < minLength, !indices.isEmpty {
                    let newLast = indices.count >= 2 ? indices[indices.count - 2] : 0
                    if n - newLast <= maxLength { indices.removeLast() }
                }
                break
            }

            indices.append(bestEnd)
            current = bestEnd
        }
        return indices
    }

    // MARK: - Helpers (ports of the private functions in constraints.py)

    /// Fixed-size chunking used when Viterbi finds no finite-score path. It cuts
    /// every `maxLength` positions and skips any chunk shorter than `minLength`.
    /// Then it repairs the tail via `handleShortFinalSegment`. The raw list may end
    /// with `n` itself; that leaves a zero-length tail, which the repair step is
    /// expected to drop.
    private static func fallbackGreedy(n: Int, minLength: Int, maxLength: Int) -> [Int] {
        // With a non-positive step `curr` would never advance. `run` already rules
        // this out; the guard keeps the helper safe on its own.
        guard maxLength > 0 else { return [] }

        var indices = [Int]()
        var curr = 0
        while curr < n {
            let next = min(curr + maxLength, n)
            if next >= curr + minLength { indices.append(next) }
            curr = next
        }
        return handleShortFinalSegment(indices, n: n, minLength: minLength, maxLength: maxLength)
    }

    /// Repairs a trailing segment `[indices.last, n)` shorter than `minLength`.
    ///
    /// - With an earlier cut: drop the last cut if merging the tail into the
    ///   previous segment stays within `maxLength`. Otherwise move the last cut to
    ///   `n - minLength` (but strictly after the earlier cut), provided the
    ///   shortened previous segment stays within `maxLength`.
    /// - With a single cut: drop it when the whole input fits in `maxLength`.
    ///   Otherwise move it to `n - minLength` if that is at least `minLength`.
    ///   NOTE(review): this move does not re-check the first segment against
    ///   `maxLength`.
    ///
    /// Returns the input unchanged when it is empty, when the tail is already
    /// long enough, or when no rule applies.
    private static func handleShortFinalSegment(_ input: [Int], n: Int,
                                                minLength: Int, maxLength: Int) -> [Int] {
        var indices = input
        guard let last = indices.last else { return indices }

        let lastChunkLen = n - last
        if lastChunkLen >= minLength { return indices }

        if indices.count > 1 {
            let prevSplit = indices[indices.count - 2]
            if n - prevSplit <= maxLength {
                indices.removeLast()
            } else {
                let desired = n - minLength
                let minValid = prevSplit + 1
                let adjusted = max(desired, minValid)
                if adjusted - prevSplit <= maxLength {
                    indices[indices.count - 1] = adjusted
                }
            }
        } else {
            if n <= maxLength {
                return []
            } else {
                let desired = n - minLength
                if desired >= minLength {
                    indices[indices.count - 1] = desired
                }
            }
        }
        return indices
    }
}