// wtpsplit-kit — Sources/wtpseg/main.swift
// (Uploaded by krmanik via huggingface_hub, commit 357ae2c.)
import Foundation
import WtpsplitKit
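// Minimal CLI mirroring scripts/run_segmentation.py for local validation.
//
// Usage:
//   wtpseg --model <path.mlpackage|.mlmodelc> --text "..." [options]
//   wtpseg --model <path.mlpackage> --file in.txt --max-length 80 --min-length 40
//
// Short forms: -m (--model), -t (--text), -f (--file). If neither --text nor
// --file is given, a built-in sample sentence is segmented.
//
// Options:
//   --max-length --min-length --overflow --target --spread   (integers)
//   --prior {uniform,gaussian,clipped_polynomial}
//   --algorithm {viterbi,greedy}
//   --allow-midword
//   --vocab {auto,full,en_zh}
// Debug modes (print raw output, then exit):
//   --tokens       token ids with their character end offsets ("id:charEnd")
//   --dump-probs   the values returned by debugCharProbs
//   --dump-chunks  the chunks separated by the record separator U+001E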
/// Reports a fatal command-line error and terminates the process.
///
/// Writes `error: <message>` followed by a newline to standard error, then
/// exits with status 1. Because it never returns, it can terminate `guard`
/// and `else` branches.
func fail(_ message: String) -> Never {
    let line = "error: \(message)\n"
    FileHandle.standardError.write(Data(line.utf8))
    exit(1)
}
/// Command-line arguments not yet consumed (program name excluded).
/// The `pop*` helpers below remove the flags they recognise from this array.
var args = Array(CommandLine.arguments.dropFirst())

/// Consumes the first occurrence of `flag` together with the argument after it.
///
/// - Returns: The argument following `flag`, or nil when `flag` is absent or
///   is the last argument. In the nil case `args` is left untouched.
func popValue(_ flag: String) -> String? {
    guard let flagIndex = args.firstIndex(of: flag) else { return nil }
    let valueIndex = args.index(after: flagIndex)
    guard valueIndex < args.endIndex else { return nil }
    let value = args[valueIndex]
    args.removeSubrange(flagIndex...valueIndex)
    return value
}

/// Consumes the first occurrence of the boolean switch `flag`.
///
/// - Returns: true if `flag` was present (and has been removed from `args`).
func popFlag(_ flag: String) -> Bool {
    if let position = args.firstIndex(of: flag) {
        args.remove(at: position)
        return true
    }
    return false
}
// Required: the Core ML model to load (-m is accepted as a short form).
let modelPath: String
if let given = popValue("--model") ?? popValue("-m") {
    modelPath = given
} else {
    fail("--model <path to .mlpackage or .mlmodelc> is required")
}

// Debug switches: each makes the tool print raw intermediate output and exit.
let printTokens = popFlag("--tokens")
let dumpProbs = popFlag("--dump-probs")
let dumpChunks = popFlag("--dump-chunks")
// Segmentation options. Each valued flag below overrides the corresponding
// SegmentationOptions default only when it is present on the command line.
var opts = SegmentationOptions()

/// Consumes an integer-valued flag.
///
/// - Returns: The parsed integer, or nil when `flag` is absent (or has no
///   following value).
/// Aborts via `fail` when the value is not a valid integer. The alternative,
/// `Int(v) ?? default`, would silently run with the default and misreport
/// the configuration that was actually used.
func popIntValue(_ flag: String) -> Int? {
    guard let raw = popValue(flag) else { return nil }
    guard let value = Int(raw) else {
        fail("\(flag) expects an integer, got '\(raw)'")
    }
    return value
}

if let v = popIntValue("--max-length") { opts.maxLength = v }
if let v = popIntValue("--min-length") { opts.minLength = v }
if let v = popIntValue("--overflow") { opts.overflow = v }
if let v = popIntValue("--target") { opts.targetLength = v }
if let v = popIntValue("--spread") { opts.spread = v }

if let v = popValue("--prior") {
    switch v {
    case "uniform": opts.prior = .uniform
    case "gaussian": opts.prior = .gaussian
    case "clipped_polynomial", "clippedPolynomial": opts.prior = .clippedPolynomial
    default: fail("unknown prior: \(v)")
    }
}

if let v = popValue("--algorithm") {
    switch v {
    case "viterbi": opts.algorithm = .viterbi
    case "greedy": opts.algorithm = .greedy
    default: fail("unknown algorithm: \(v)")
    }
}

// Assigned unconditionally: without the switch this is false for the CLI,
// regardless of the SegmentationOptions default.
opts.allowMidword = popFlag("--allow-midword")
// Tokenizer vocabulary selection; "auto" when --vocab is not given.
let vocab: Vocabulary
switch popValue("--vocab") {
case nil, "auto"?:
    vocab = .auto
case "full"?:
    vocab = .full
case "en_zh"?, "enZh"?:
    vocab = .enZh
case let other?:
    fail("unknown vocab: \(other)")
}
// Input text: either --text/-t or --file/-f (not both). With neither, a
// built-in sample sentence is used.
let inlineText = popValue("--text") ?? popValue("-t")
let inputFile = popValue("--file") ?? popValue("-f")

// Every supported flag has been consumed by now. Anything left over is a
// typo, an unsupported option, a repeated flag, or a flag given without its
// value. Reject it instead of silently running with default settings.
if !args.isEmpty {
    fail("unexpected arguments (unknown, repeated, or missing a value): "
        + args.joined(separator: " "))
}

let text: String
switch (inlineText, inputFile) {
case (let literal?, nil):
    text = literal
case (nil, let path?):
    do {
        text = try String(contentsOfFile: path, encoding: .utf8)
    } catch {
        // Keep the underlying reason (missing file, bad encoding, ...).
        fail("cannot read \(path): \(error.localizedDescription)")
    }
case (_?, _?):
    fail("--text and --file are mutually exclusive")
case (nil, nil):
    text = "Breaking News: Scientists announced a discovery. 这是一个测试。It works well!"
}
// Load the model, then either emit one of the debug dumps (each exits right
// away) or segment `text` and print a human-readable report.
do {
    let url = URL(fileURLWithPath: modelPath)
    // systemUptime is monotonic. Unlike Date(), it is unaffected by
    // wall-clock adjustments, so the reported durations cannot be skewed.
    let loadStart = ProcessInfo.processInfo.systemUptime
    let segmenter = try SaTSegmenter(modelURL: url, vocabulary: vocab)
    if printTokens {
        // One "id:charEnd" pair per token, space separated. zip avoids an
        // out-of-range trap if the two arrays ever differ in length.
        let t = segmenter.debugTokens(text)
        let pairs = zip(t.ids, t.charEnds).map { id, end in "\(id):\(end)" }
        print("TOKENS " + pairs.joined(separator: " "))
        exit(0)
    }
    if dumpProbs {
        let probs = try segmenter.debugCharProbs(text)
        print("PROBS " + probs.map { String(format: "%.9g", $0) }.joined(separator: " "))
        exit(0)
    }
    if dumpChunks {
        let chunks = try segmenter.segment(text, options: opts)
        // Record-separator delimited so chunks compare byte-exactly.
        print("CHUNKS" + "\u{1e}" + chunks.joined(separator: "\u{1e}"))
        exit(0)
    }
    let loadMs = (ProcessInfo.processInfo.systemUptime - loadStart) * 1000

    let segStart = ProcessInfo.processInfo.systemUptime
    let chunks = try segmenter.segment(text, options: opts)
    let segMs = (ProcessInfo.processInfo.systemUptime - segStart) * 1000

    print(String(format: "Model load+compile: %.0f ms segment: %.1f ms", loadMs, segMs))
    print("Config: max=\(opts.maxLength) overflow=\(opts.overflow) min=\(opts.minLength) "
        + "prior=\(opts.prior) algo=\(opts.algorithm)")
    print("Input: \(text.unicodeScalars.count) chars -> \(chunks.count) chunks\n")

    // Per-chunk marker: "!" when longer than maxLength + overflow, "+" when
    // longer than maxLength but within the overflow allowance, blank otherwise.
    let hardMax = opts.maxLength + max(0, opts.overflow)
    for c in chunks {
        let n = c.unicodeScalars.count
        let flag = n > hardMax ? "!" : (n > opts.maxLength ? "+" : " ")
        let body = c.trimmingCharacters(in: .whitespacesAndNewlines)
        print(" \(flag)[\(String(format: "%3d", n))] \(body.prefix(90))")
    }

    // Concatenating the chunks must reproduce the input exactly.
    let rejoined = chunks.joined()
    print(rejoined == text ? "\n ✓ text preserved (chunks rejoin to original)"
        : "\n ✗ TEXT NOT PRESERVED")
} catch {
    fail("\(error)")
}