| import Foundation |
| import WtpsplitKit |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
/// Writes `error: <msg>` plus a newline to standard error, then terminates with exit status 1.
///
/// - Parameter msg: Human-readable description of what went wrong.
func fail(_ msg: String) -> Never {
    FileHandle.standardError.write(Data("error: \(msg)\n".utf8))
    exit(1)
}

/// Command-line arguments that have not been consumed yet (the executable path is dropped).
/// `popValue` and `popFlag` remove entries from this array as they are recognised.
var args = Array(CommandLine.arguments.dropFirst())

/// Removes the first occurrence of `flag` and the argument that follows it from `args`.
///
/// The flag is matched by exact string equality anywhere in `args`.
///
/// - Parameter flag: Option name to look for, e.g. `"--model"`.
/// - Returns: The argument following `flag`, or `nil` when `flag` is not present.
///   When `flag` is the last remaining argument (no value after it) the process
///   exits through `fail` instead of silently ignoring the option.
func popValue(_ flag: String) -> String? {
    guard let flagIndex = args.firstIndex(of: flag) else { return nil }
    let valueIndex = args.index(after: flagIndex)
    guard valueIndex < args.endIndex else {
        // A dangling option used to be treated as absent, hiding the mistake.
        fail("\(flag) requires a value")
    }
    let value = args[valueIndex]
    args.removeSubrange(flagIndex...valueIndex)
    return value
}

/// Removes the first occurrence of the boolean switch `flag` from `args`.
///
/// - Parameter flag: Switch name to look for, e.g. `"--tokens"`.
/// - Returns: `true` when the switch was present (and has been removed), `false` otherwise.
func popFlag(_ flag: String) -> Bool {
    guard let flagIndex = args.firstIndex(of: flag) else { return false }
    args.remove(at: flagIndex)
    return true
}
|
|
// The model location is mandatory; `--model` is consulted before its short form `-m`.
let modelPath: String
if let requestedPath = popValue("--model") ?? popValue("-m") {
    modelPath = requestedPath
} else {
    fail("--model <path to .mlpackage or .mlmodelc> is required")
}

// Mode switches: each one selects a dump format that is printed instead of the normal report.
let printTokens = popFlag("--tokens")
let dumpProbs = popFlag("--dump-probs")
let dumpChunks = popFlag("--dump-chunks")
|
|
/// Pops `flag` with `popValue` and converts its value to an `Int`.
///
/// - Parameter flag: Option name whose value must be an integer.
/// - Returns: The parsed integer, or `nil` when `popValue` yields nothing for `flag`.
///   A value that is not a valid integer terminates the process through `fail`
///   rather than being silently replaced by the default.
func popInt(_ flag: String) -> Int? {
    guard let raw = popValue(flag) else { return nil }
    guard let number = Int(raw) else {
        fail("invalid integer for \(flag): \(raw)")
    }
    return number
}

// Segmentation options: start from the library defaults and override only what was passed.
var opts = SegmentationOptions()
if let n = popInt("--max-length") { opts.maxLength = n }
if let n = popInt("--min-length") { opts.minLength = n }
if let n = popInt("--overflow") { opts.overflow = n }
if let n = popInt("--target") { opts.targetLength = n }
if let n = popInt("--spread") { opts.spread = n }

// Length prior; both snake_case and camelCase spellings of the polynomial prior are accepted.
if let v = popValue("--prior") {
    switch v {
    case "uniform": opts.prior = .uniform
    case "gaussian": opts.prior = .gaussian
    case "clipped_polynomial", "clippedPolynomial": opts.prior = .clippedPolynomial
    default: fail("unknown prior: \(v)")
    }
}

// Boundary-selection algorithm.
if let v = popValue("--algorithm") {
    switch v {
    case "viterbi": opts.algorithm = .viterbi
    case "greedy": opts.algorithm = .greedy
    default: fail("unknown algorithm: \(v)")
    }
}

// NOTE(review): this assigns unconditionally, so omitting the switch sets the option to
// false regardless of the default in SegmentationOptions — confirm that is intended.
opts.allowMidword = popFlag("--allow-midword")
|
|
// Vocabulary selection: defaults to `.auto`; `--vocab` accepts auto, full, or en_zh / enZh.
var vocab: Vocabulary = .auto
if let requested = popValue("--vocab") {
    if requested == "auto" {
        vocab = .auto
    } else if requested == "full" {
        vocab = .full
    } else if requested == "en_zh" || requested == "enZh" {
        vocab = .enZh
    } else {
        fail("unknown vocab: \(requested)")
    }
}
|
|
// Input text, in priority order: inline `--text`/`-t`, then the UTF-8 file named by
// `--file`/`-f`, then a built-in sample that mixes English and Chinese.
let text: String
if let inline = popValue("--text") ?? popValue("-t") {
    text = inline
} else if let path = popValue("--file") ?? popValue("-f") {
    do {
        text = try String(contentsOfFile: path, encoding: .utf8)
    } catch {
        // Surface the underlying reason (missing file, permissions, bad encoding)
        // instead of discarding it.
        fail("cannot read \(path): \(error.localizedDescription)")
    }
} else {
    text = "Breaking News: Scientists announced a discovery. 这是一个测试。It works well!"
}
|
|
// Load the model, then either emit one machine-readable dump and exit,
// or segment the text and print a timed, human-readable report.
do {
    let modelURL = URL(fileURLWithPath: modelPath)
    let loadStart = Date()
    let segmenter = try SaTSegmenter(modelURL: modelURL, vocabulary: vocab)

    // --tokens: one "id:charEnd" pair per token on a single line.
    if printTokens {
        let tokens = segmenter.debugTokens(text)
        let pairs = (0..<tokens.ids.count).map { index in
            "\(tokens.ids[index]):\(tokens.charEnds[index])"
        }
        print("TOKENS " + pairs.joined(separator: " "))
        exit(0)
    }

    // --dump-probs: the values from debugCharProbs, formatted with "%.9g".
    if dumpProbs {
        let formatted = try segmenter.debugCharProbs(text).map { probability in
            String(format: "%.9g", probability)
        }
        print("PROBS " + formatted.joined(separator: " "))
        exit(0)
    }

    // --dump-chunks: chunks delimited by U+001E so that chunk contents stay unambiguous.
    if dumpChunks {
        let separator = "\u{1e}"
        let pieces = try segmenter.segment(text, options: opts)
        print("CHUNKS" + separator + pieces.joined(separator: separator))
        exit(0)
    }

    let loadMs = 1000 * Date().timeIntervalSince(loadStart)

    let segmentStart = Date()
    let chunks = try segmenter.segment(text, options: opts)
    let segMs = 1000 * Date().timeIntervalSince(segmentStart)

    let timing = String(format: "Model load+compile: %.0f ms segment: %.1f ms", loadMs, segMs)
    print(timing)
    let limits = "max=\(opts.maxLength) overflow=\(opts.overflow) min=\(opts.minLength)"
    print("Config: \(limits) prior=\(opts.prior) algo=\(opts.algorithm)")
    print("Input: \(text.unicodeScalars.count) chars -> \(chunks.count) chunks\n")

    // Per-chunk listing. Lengths are Unicode scalar counts; the marker is "!" when a chunk
    // exceeds maxLength plus the (non-negative) overflow, "+" when it only exceeds maxLength.
    let hardMax = opts.maxLength + max(0, opts.overflow)
    for chunk in chunks {
        let scalarCount = chunk.unicodeScalars.count
        let marker: String
        if scalarCount > hardMax {
            marker = "!"
        } else if scalarCount > opts.maxLength {
            marker = "+"
        } else {
            marker = " "
        }
        let width = String(format: "%3d", scalarCount)
        let preview = chunk.trimmingCharacters(in: .whitespacesAndNewlines).prefix(90)
        print(" \(marker)[\(width)] \(preview)")
    }

    // Concatenating the chunks should reproduce the input.
    if chunks.joined() == text {
        print("\n ✓ text preserved (chunks rejoin to original)")
    } else {
        print("\n ✗ TEXT NOT PRESERVED")
    }
} catch {
    fail("\(error)")
}
|
|