File size: 4,152 Bytes
357ae2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import Foundation

/// Swift port of `pause_aware_mask` (also known as `word_safe_mask`) from
/// `scripts/run_segmentation.py`.
///
/// Nudges forced breaks toward natural prosodic pauses so TTS does not stop
/// mid-phrase. `probs[i]` is the boundary probability *after* scalar `i`
/// (between `i` and `i+1`). At plausible pause points the probability is only
/// ever raised to a floor, so strong model predictions survive. Everywhere else
/// it is capped near zero so words and abbreviations are not cut.
///
/// Rules are checked in this order; the first match wins:
/// 1. `CharClasses.clausePunct` followed by whitespace or CJK → floor 0.25
/// 2. `CharClasses.cjkSentencePunct` → floor 0.9
/// 3. last scalar before a whitespace run that precedes a connector word → floor 0.05
/// 4. either side of the slot is whitespace → floor 1e-4
/// 5. both sides of the slot are CJK → floor 5e-3
/// 6. anything else (mid-word / abbreviation) → capped at 1e-9
enum PauseMask {

    /// Break-priority levels, kept identical to the reference constants.
    private enum Level {
        static let clause: Double = 0.25       // clause punctuation before space/CJK
        static let cjkSentence: Double = 0.9   // CJK sentence punctuation
        static let connector: Double = 0.05    // right before a connector word
        static let hanzi: Double = 5e-3        // between two CJK scalars
        static let space: Double = 1e-4        // next to whitespace
        static let midWord: Double = 1e-9      // a ceiling, not a floor
    }

    /// Returns a copy of `probs` with the pause-aware floors and the mid-word
    /// ceiling applied.
    ///
    /// - Parameters:
    ///   - probs: Boundary probabilities. Callers pass one per scalar
    ///     (`probs.count == scalars.count`); a shorter array traps on indexing.
    ///   - scalars: The original text as Unicode scalars.
    /// - Returns: The adjusted probabilities. Inputs with fewer than two
    ///   scalars come back unchanged, and the final slot is never modified.
    static func apply(_ probs: [Double], scalars: [Unicode.Scalar]) -> [Double] {
        guard scalars.count >= 2 else { return probs }

        let connectorSlots = connectorBreakPositions(scalars)
        var adjusted = probs
        // Visit slots 0..<count-1 as (scalar, successor) pairs. The last slot
        // has no successor and is deliberately left as-is: no break is forced
        // in front of the end-of-text marker.
        for (slot, (current, next)) in zip(scalars, scalars.dropFirst()).enumerated() {
            adjusted[slot] = bias(
                adjusted[slot],
                current: current,
                next: next,
                precedesConnector: connectorSlots.contains(slot)
            )
        }
        return adjusted
    }

    /// Applies the first matching rule (see the type documentation) to the
    /// probability of the slot between `current` and `next`.
    private static func bias(
        _ prob: Double,
        current: Unicode.Scalar,
        next: Unicode.Scalar,
        precedesConnector: Bool
    ) -> Double {
        let nextStartsNewToken = CharClasses.isSpace(next) || CharClasses.isCJK(next)

        if nextStartsNewToken && CharClasses.clausePunct.contains(current) {
            return max(prob, Level.clause)
        }
        if CharClasses.cjkSentencePunct.contains(current) {
            return max(prob, Level.cjkSentence)
        }
        if precedesConnector {
            return max(prob, Level.connector)
        }
        if CharClasses.isSpace(current) || CharClasses.isSpace(next) {
            return max(prob, Level.space)
        }
        if CharClasses.isCJK(current) && CharClasses.isCJK(next) {
            return max(prob, Level.hanzi)
        }
        // Inside a word or abbreviation: push the probability toward zero.
        return min(prob, Level.midWord)
    }

    /// Slots (break after scalar `k`) that sit immediately before a connector word.
    ///
    /// Mirrors the regex `\s+(\S+)` scan of the reference. Each whitespace run
    /// is paired with the non-whitespace word that follows it. The word is
    /// stripped of ASCII punctuation and lowercased, then looked up in
    /// `CharClasses.connectors`. On a hit, the scalar just before the whitespace
    /// run is recorded. A whitespace run at the very start of the text has no
    /// preceding scalar and is skipped.
    private static func connectorBreakPositions(_ scalars: [Unicode.Scalar]) -> Set<Int> {
        var slots = Set<Int>()
        var cursor = scalars.startIndex

        while let gapStart = scalars[cursor...].firstIndex(where: { CharClasses.isSpace($0) }) {
            let wordStart = scalars[gapStart...].firstIndex(where: { !CharClasses.isSpace($0) })
                ?? scalars.endIndex
            let wordEnd = scalars[wordStart...].firstIndex(where: { CharClasses.isSpace($0) })
                ?? scalars.endIndex
            // Resume at the whitespace following this word (non-overlapping matches).
            cursor = wordEnd

            guard gapStart > 0, wordStart < wordEnd else { continue }

            let token = stripAsciiPunctuation(scalars[wordStart..<wordEnd]).lowercased()
            if CharClasses.connectors.contains(token) {
                slots.insert(gapStart - 1)
            }
        }
        return slots
    }

    /// Counterpart of Python's `str.strip(string.punctuation)`: drops scalars in
    /// `CharClasses.asciiPunctuation` from both ends of `slice`.
    private static func stripAsciiPunctuation(_ slice: ArraySlice<Unicode.Scalar>) -> String {
        let isPunctuation: (Unicode.Scalar) -> Bool = { CharClasses.asciiPunctuation.contains($0) }
        guard
            let first = slice.firstIndex(where: { !isPunctuation($0) }),
            let last = slice.lastIndex(where: { !isPunctuation($0) })
        else {
            // Empty or punctuation-only word.
            return ""
        }
        return String(String.UnicodeScalarView(slice[first...last]))
    }
}