// Hosting-page residue (not part of the Swift source), kept as comments:
//   krmanik's picture
//   Upload folder using huggingface_hub
//   357ae2c verified
import Foundation
/// Port of `pause_aware_mask` (a.k.a. `word_safe_mask`) from
/// `scripts/run_segmentation.py`.
///
/// Biases forced breaks toward natural prosodic pauses so TTS does not pause
/// mid-phrase. `probs[i]` is the boundary probability *after* scalar `i`
/// (i.e. between scalars `i` and `i + 1`). Probabilities that already exceed a
/// floor are left alone, so confident model predictions keep dominating; for
/// everything else a floor is raised according to pause strength:
///
///     clause punctuation > connector word > between-hanzi > plain word gap
///
/// Positions that match none of these (mid-word) are capped at ~0 so
/// words/abbreviations are never cut.
enum PauseMask {
    // MARK: - Break-priority floors (kept identical to the reference constants)

    /// Floor after clause punctuation (members of `CharClasses.clausePunct`)
    /// when the following scalar is whitespace or CJK.
    private static let floorClause: Double = 0.25
    /// Floor at the end of a word that is followed by a connector word
    /// (and / which / that / …).
    private static let floorConnector: Double = 0.05
    /// Floor between two adjacent CJK scalars.
    private static let floorHanzi: Double = 5e-3
    /// Floor at a plain word gap (either side of the position is whitespace).
    private static let floorSpace: Double = 1e-4
    /// Floor after CJK sentence punctuation (members of
    /// `CharClasses.cjkSentencePunct`).
    private static let cjkSentenceFloor: Double = 0.9
    /// Ceiling applied to every remaining position (mid-word / abbreviation).
    private static let forbid: Double = 1e-9

    // MARK: - Entry point

    /// Returns a copy of `probs` with the pause-aware mask applied; the input
    /// array itself is not mutated.
    ///
    /// - Parameters:
    ///   - probs: Boundary probability after each scalar. Expected to have the
    ///     same length as `scalars`.
    ///   - scalars: The original text as Unicode scalars.
    /// - Returns: The adjusted probabilities, same length as `probs`. The entry
    ///   for the final scalar is never modified. If `probs` is shorter than
    ///   `scalars`, only the positions present in both arrays are adjusted
    ///   instead of trapping on an out-of-range index.
    static func apply(_ probs: [Double], scalars: [Unicode.Scalar]) -> [Double] {
        // The last scalar has no successor, so the final position is skipped.
        // Clamping to `probs.count` keeps every `adjusted[i]` access in bounds
        // even when the caller violates the equal-length expectation.
        let upperBound = min(scalars.count - 1, probs.count)
        guard upperBound > 0 else { return probs }

        var adjusted = probs
        let connectorBreaks = connectorBreakPositions(scalars)

        for i in 0..<upperBound {
            let current = scalars[i]
            let next = scalars[i + 1]
            let nextIsSpace = CharClasses.isSpace(next)

            // Branch order encodes priority: the first matching rule wins.
            if CharClasses.clausePunct.contains(current) && (nextIsSpace || CharClasses.isCJK(next)) {
                // Clause punctuation only counts when it actually ends a token
                // (so a comma inside "1,000" falls through to the mid-word cap).
                adjusted[i] = max(adjusted[i], floorClause)
            } else if CharClasses.cjkSentencePunct.contains(current) {
                adjusted[i] = max(adjusted[i], cjkSentenceFloor)
            } else if connectorBreaks.contains(i) {
                adjusted[i] = max(adjusted[i], floorConnector)
            } else if nextIsSpace || CharClasses.isSpace(current) {
                adjusted[i] = max(adjusted[i], floorSpace)
            } else if CharClasses.isCJK(current) && CharClasses.isCJK(next) {
                adjusted[i] = max(adjusted[i], floorHanzi)
            } else {
                // Mid-word: cap rather than floor so a break is effectively forbidden.
                adjusted[i] = min(adjusted[i], forbid)
            }
        }
        return adjusted
    }

    // MARK: - Helpers

    /// Indices `i` (break after scalar `i`) that sit right before a connector
    /// word.
    ///
    /// Mirrors the regex `\s+(\S+)` scan in the reference: for each whitespace
    /// run, take the following non-whitespace run, strip ASCII punctuation from
    /// both ends, lowercase it, and if the result is in
    /// `CharClasses.connectors`, record the index of the scalar just before the
    /// whitespace run (the last scalar of the preceding word).
    private static func connectorBreakPositions(_ scalars: [Unicode.Scalar]) -> Set<Int> {
        var positions = Set<Int>()
        let count = scalars.count
        var cursor = 0

        while cursor < count {
            // Advance to the start of the next whitespace run.
            guard CharClasses.isSpace(scalars[cursor]) else {
                cursor += 1
                continue
            }
            let runStart = cursor
            while cursor < count && CharClasses.isSpace(scalars[cursor]) { cursor += 1 }

            // The candidate word is the non-whitespace run that follows.
            let wordStart = cursor
            while cursor < count && !CharClasses.isSpace(scalars[cursor]) { cursor += 1 }

            // Skip trailing whitespace (no word) and leading whitespace (no
            // preceding scalar to attach the break to). `cursor` has already
            // moved past the run, so the scan still makes progress.
            guard wordStart < cursor, runStart > 0 else { continue }

            let word = stripAsciiPunctuation(scalars[wordStart..<cursor]).lowercased()
            if CharClasses.connectors.contains(word) {
                positions.insert(runStart - 1)
            }
        }
        return positions
    }

    /// Equivalent of Python `str.strip(string.punctuation)`: removes leading and
    /// trailing scalars that are in `CharClasses.asciiPunctuation` and returns
    /// what remains as a `String` (possibly empty).
    private static func stripAsciiPunctuation(_ slice: ArraySlice<Unicode.Scalar>) -> String {
        var core = slice.drop(while: { CharClasses.asciiPunctuation.contains($0) })
        while let last = core.last, CharClasses.asciiPunctuation.contains(last) {
            core.removeLast()
        }
        var view = String.UnicodeScalarView()
        view.append(contentsOf: core)
        return String(view)
    }
}