File size: 4,690 Bytes
357ae2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import Foundation
import WtpsplitKit

// Minimal CLI mirroring scripts/run_segmentation.py for local validation.
//
//   wtpseg --model <path.mlpackage> --text "..." [options]
//   wtpseg --model <path.mlpackage> --file in.txt --max-length 80 --min-length 40
//
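// Options: --max-length --min-length --overflow --prior {uniform,gaussian,
//          clipped_polynomial} --target --spread --algorithm {viterbi,greedy}
//          --allow-midword --vocab {auto,full,en_zh}
// Input:   --text/-t <string> or --file/-f <path>; with neither, a built-in
//          sample sentence is segmented.
// Debug (each prints one line and exits): --tokens (token id:charEnd pairs)
//          --dump-probs (debugCharProbs values) --dump-chunks (chunks,
//          U+001E-separated)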

/// Writes "error: <msg>" followed by a newline to standard error and terminates
/// the process with exit status 1.
func fail(_ msg: String) -> Never {
    FileHandle.standardError.write(Data(("error: " + msg + "\n").utf8))
    exit(1)
}

// Arguments not yet consumed by the option parsing below (program name excluded).
var args = Array(CommandLine.arguments.dropFirst())

/// Removes the first occurrence of `flag` and the argument that follows it from
/// `args`, returning that following argument.
///
/// Returns nil when `flag` is not present. When `flag` is present but is the
/// last remaining argument, the process exits through `fail` instead of
/// treating the flag as absent, so a forgotten value cannot silently fall back
/// to a default.
func popValue(_ flag: String) -> String? {
    guard let i = args.firstIndex(of: flag) else { return nil }
    guard i + 1 < args.count else { fail("\(flag) requires a value") }
    let v = args[i + 1]
    args.removeSubrange(i...(i + 1))
    return v
}

/// Removes the first occurrence of `flag` from `args` and reports whether it
/// was present.
func popFlag(_ flag: String) -> Bool {
    guard let i = args.firstIndex(of: flag) else { return false }
    args.remove(at: i)
    return true
}

// The model path is mandatory. The long spelling is tried first; "-m" is only
// consulted when "--model" yields nothing.
let modelPath: String
if let longForm = popValue("--model") {
    modelPath = longForm
} else if let shortForm = popValue("-m") {
    modelPath = shortForm
} else {
    fail("--model <path to .mlpackage or .mlmodelc> is required")
}

// Debug switches, consumed from the argument list in this order.
let printTokens = popFlag("--tokens")
let dumpProbs = popFlag("--dump-probs")
let dumpChunks = popFlag("--dump-chunks")

/// Pops the value of `flag` via `popValue` and parses it as an integer.
///
/// Returns nil when `popValue` yields nothing for the flag. Exits through
/// `fail` when a value was supplied but is not a valid integer, rather than
/// silently keeping the default setting.
func popInt(_ flag: String) -> Int? {
    guard let raw = popValue(flag) else { return nil }
    guard let number = Int(raw) else {
        fail("\(flag) expects an integer, got \"\(raw)\"")
    }
    return number
}

// Start from the default options and override only what the command line supplies.
var opts = SegmentationOptions()
if let n = popInt("--max-length") { opts.maxLength = n }
if let n = popInt("--min-length") { opts.minLength = n }
if let n = popInt("--overflow") { opts.overflow = n }
if let n = popInt("--target") { opts.targetLength = n }
if let n = popInt("--spread") { opts.spread = n }
if let v = popValue("--prior") {
    // Both the snake_case and the camelCase spelling are accepted.
    switch v {
    case "uniform": opts.prior = .uniform
    case "gaussian": opts.prior = .gaussian
    case "clipped_polynomial", "clippedPolynomial": opts.prior = .clippedPolynomial
    default: fail("unknown prior: \(v)")
    }
}
if let v = popValue("--algorithm") {
    switch v {
    case "viterbi": opts.algorithm = .viterbi
    case "greedy": opts.algorithm = .greedy
    default: fail("unknown algorithm: \(v)")
    }
}
// Set unconditionally: without the flag, allowMidword becomes false.
opts.allowMidword = popFlag("--allow-midword")

// Vocabulary selection: stays .auto unless --vocab names a known vocabulary.
// Any other name is a fatal error.
var vocab: Vocabulary = .auto
if let requested = popValue("--vocab") {
    // Both the snake_case and the camelCase spelling map to .enZh.
    let knownVocabularies: [String: Vocabulary] = [
        "auto": .auto,
        "full": .full,
        "en_zh": .enZh,
        "enZh": .enZh,
    ]
    guard let choice = knownVocabularies[requested] else {
        fail("unknown vocab: \(requested)")
    }
    vocab = choice
}

// Input text: an inline --text/-t value wins, then the contents of --file/-f
// (read as UTF-8), and finally a built-in mixed English/Chinese sample.
let text: String
if let inline = popValue("--text") ?? popValue("-t") {
    text = inline
} else if let path = popValue("--file") ?? popValue("-f") {
    do {
        text = try String(contentsOfFile: path, encoding: .utf8)
    } catch {
        // Report why the read failed (missing file, permissions, invalid
        // UTF-8, ...) instead of discarding the underlying error.
        fail("cannot read \(path): \(error.localizedDescription)")
    }
} else {
    text = "Breaking News: Scientists announced a discovery. 这是一个测试。It works well!"
}

// Load the model, then either emit one of the single-line debug dumps and exit,
// or print a human-readable segmentation report. Any thrown error is reported
// through `fail`.
do {
    let url = URL(fileURLWithPath: modelPath)
    let clock = Date()
    let segmenter = try SaTSegmenter(modelURL: url, vocabulary: vocab)
    // Stop the load timer right away so it covers only segmenter construction.
    let loadMs = Date().timeIntervalSince(clock) * 1000

    if printTokens {
        // One "id:charEnd" pair per token, space-separated on a single line.
        let t = segmenter.debugTokens(text)
        var lines = [String]()
        for k in 0..<t.ids.count { lines.append("\(t.ids[k]):\(t.charEnds[k])") }
        print("TOKENS " + lines.joined(separator: " "))
        exit(0)
    }
    if dumpProbs {
        let probs = try segmenter.debugCharProbs(text)
        print("PROBS " + probs.map { String(format: "%.9g", $0) }.joined(separator: " "))
        exit(0)
    }
    if dumpChunks {
        let chunks = try segmenter.segment(text, options: opts)
        // Record-separator delimited so chunks compare byte-exactly.
        print("CHUNKS" + "\u{1e}" + chunks.joined(separator: "\u{1e}"))
        exit(0)
    }

    let t0 = Date()
    let chunks = try segmenter.segment(text, options: opts)
    let segMs = Date().timeIntervalSince(t0) * 1000

    print(String(format: "Model load+compile: %.0f ms   segment: %.1f ms", loadMs, segMs))
    print("Config: max=\(opts.maxLength) overflow=\(opts.overflow) min=\(opts.minLength) "
        + "prior=\(opts.prior) algo=\(opts.algorithm)")
    print("Input: \(text.unicodeScalars.count) chars -> \(chunks.count) chunks\n")

    // Per-chunk marker: "!" when the chunk exceeds maxLength plus any positive
    // overflow, "+" when it exceeds maxLength only, blank otherwise. Lengths
    // are counted in Unicode scalars.
    let hardMax = opts.maxLength + max(0, opts.overflow)
    for c in chunks {
        let n = c.unicodeScalars.count
        let flag = n > hardMax ? "!" : (n > opts.maxLength ? "+" : " ")
        let body = c.trimmingCharacters(in: .whitespacesAndNewlines)
        print("  \(flag)[\(String(format: "%3d", n))] \(body.prefix(90))")
    }

    // Compare scalar by scalar: String `==` uses Unicode canonical equivalence,
    // so it would also accept output whose normalization form had changed.
    let rejoined = chunks.joined()
    let preserved = rejoined.unicodeScalars.elementsEqual(text.unicodeScalars)
    print(preserved ? "\n  ✓ text preserved (chunks rejoin to original)"
                    : "\n  ✗ TEXT NOT PRESERVED")
} catch {
    fail("\(error)")
}