| """pico-type CLI: classify content from stdin/file/clipboard.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
|
|
| import numpy as np |
|
|
| from .labels import ALL_HEADS, COARSE_LABELS, MODALITY_LABELS, SUBTYPE_LABELS, CODE_LANG_LABELS, TEXT_LANG_LABELS, FILE_MIME_LABELS, RISK_LABELS |
|
|
# Head name -> label vocabulary for that head. run_onnx and run_torch index
# LABEL_TABLES[head] for every head in ALL_HEADS and map the argmax position
# back to a label, so each table's order must match the model's output order.
LABEL_TABLES = dict(
    coarse=COARSE_LABELS,
    modality=MODALITY_LABELS,
    subtype=SUBTYPE_LABELS,
    code_lang=CODE_LANG_LABELS,
    text_lang=TEXT_LANG_LABELS,
    file_mime=FILE_MIME_LABELS,
    risk=RISK_LABELS,
)
|
|
|
|
def load_onnx_model(tier: str = "base", model_dir: str = "checkpoints", *, providers=None):
    """Open ``<model_dir>/picotype_<tier>.onnx`` as an ONNX Runtime session.

    Args:
        tier: model tier; selects the file name ``picotype_<tier>.onnx``.
        model_dir: directory containing the exported ONNX models.
        providers: execution providers handed to ``InferenceSession``. When
            None (default), the CPU provider is used if the build has it,
            otherwise every provider the build reports. An explicit list is
            required on ONNX Runtime builds that ship more than one provider
            (e.g. GPU builds); without one, they raise ``ValueError``.

    Returns:
        An ``onnxruntime.InferenceSession`` for the model file.

    Raises:
        FileNotFoundError: the model file does not exist.
    """
    path = os.path.join(model_dir, f"picotype_{tier}.onnx")
    # Check the file first so a missing model is reported even when
    # onnxruntime itself is not installed.
    if not os.path.exists(path):
        raise FileNotFoundError(f"ONNX model not found: {path}")

    import onnxruntime as ort

    if providers is None:
        available = ort.get_available_providers()
        providers = ["CPUExecutionProvider"] if "CPUExecutionProvider" in available else available
    return ort.InferenceSession(path, providers=providers)
|
|
|
|
def load_torch_model(tier: str = "base", checkpoint: str = ""):
    """Instantiate ``PicoType`` and restore its weights from a checkpoint file.

    Args:
        tier: model tier. It is not used to build the model here; it is only
            passed back so the caller can hand it to ``run_torch``.
        checkpoint: path accepted by ``torch.load``. The loaded object is either
            a state dict or a dict holding one under ``"model_state_dict"``.

    Returns:
        ``(model, tier)``, with the model switched to eval mode and its
        tensors loaded with ``map_location="cpu"``.
    """
    import torch
    from .arch import PicoType, PicoTypeConfig

    # max_bytes=1024 presumably mirrors run_torch's default max_bytes; TODO confirm.
    net = PicoType(PicoTypeConfig(max_bytes=1024))
    # NOTE(review): depending on the torch version's weights_only default,
    # torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    saved = torch.load(checkpoint, map_location="cpu")
    weights = saved.get("model_state_dict", saved)
    net.load_state_dict(weights)
    net.eval()
    return net, tier
|
|
|
|
def run_onnx(session, text: str, max_bytes: int = 1024) -> dict:
    """Classify *text* with an ONNX Runtime session.

    The text is UTF-8 encoded and cut to at most *max_bytes* bytes (a
    multi-byte character may be split at the boundary). It is then
    zero-padded to exactly *max_bytes*, with a boolean attention mask that
    marks the real bytes.

    Args:
        session: object with an ONNX Runtime-style ``run(output_names, feed)``
            method, e.g. the session returned by ``load_onnx_model``.
        text: input to classify.
        max_bytes: fixed sequence length fed to the model. It presumably has
            to match the length the model was exported with; TODO confirm.

    Returns:
        Dict keyed by head name. The ``"risk"`` head maps every risk label to
        its softmax probability. Every other head is
        ``{"label": ..., "confidence": float, "index": int}`` for its argmax class.
    """
    # Slicing alone enforces the byte cap, so no further length check is needed.
    encoded = text.encode("utf-8")[:max_bytes]
    n = len(encoded)

    # Batch of one: shape (1, max_bytes).
    input_ids = np.zeros((1, max_bytes), dtype=np.int64)
    input_ids[0, :n] = np.frombuffer(encoded, dtype=np.uint8)
    attention_mask = np.zeros((1, max_bytes), dtype=np.bool_)
    attention_mask[0, :n] = True

    outputs = session.run(None, {"input_ids": input_ids, "attention_mask": attention_mask})

    result = {}
    # NOTE(review): outputs are matched to heads purely by position, so the
    # model's output order must equal ALL_HEADS; zip() silently drops any
    # surplus or missing outputs.
    for head, logits in zip(ALL_HEADS, outputs):
        probs = _softmax(logits[0])  # logits[0]: the single row of the batch
        labels = LABEL_TABLES[head]
        if head == "risk":
            result[head] = {labels[i]: float(p) for i, p in enumerate(probs)}
        else:
            best = int(np.argmax(probs))
            result[head] = {"label": labels[best], "confidence": float(probs[best]), "index": best}
    return result
|
|
|
|
| def _softmax(x): |
| e = np.exp(x - np.max(x)) |
| return e / e.sum() |
|
|
|
|
def run_torch(model, tier: str, text: str, max_bytes: int = 1024) -> dict:
    """Classify *text* with a PyTorch PicoType model.

    The UTF-8 bytes of *text*, capped at *max_bytes*, are fed unpadded with an
    all-True attention mask.

    Args:
        model: the model, or the ``(model, tier)`` pair from ``load_torch_model``.
        tier: key used to select per-tier logits when a head's output is a dict.
        text: input to classify.
        max_bytes: maximum number of UTF-8 bytes fed to the model.

    Returns:
        Dict keyed by head name. ``"risk"`` maps every risk label to its
        probability; other heads are ``{"label", "confidence", "index"}``
        for the argmax class.
    """
    import torch

    if isinstance(model, tuple):
        # Accept load_torch_model's (model, tier) return value as-is.
        model = model[0]

    byte_values = list(text.encode("utf-8")[:max_bytes])
    input_ids = torch.tensor([byte_values], dtype=torch.long)
    attention_mask = torch.ones(1, input_ids.shape[1], dtype=torch.bool)
    with torch.no_grad():
        head_outputs = model(input_ids, attention_mask)

    predictions = {}
    for head in ALL_HEADS:
        logits = head_outputs[head]
        # A head may return per-tier logits, possibly nested one level deeper.
        if isinstance(logits, dict):
            logits = logits[tier]
        if isinstance(logits, dict):
            logits = logits[tier]
        probs = torch.softmax(logits[0], dim=-1)
        labels = LABEL_TABLES[head]
        if head == "risk":
            predictions[head] = {labels[i]: float(p) for i, p in enumerate(probs)}
        else:
            best = int(torch.argmax(probs).item())
            predictions[head] = {"label": labels[best], "confidence": float(probs[best]), "index": best}
    return predictions
|
|
|
|
def _read_clipboard() -> str:
    """Return clipboard text from the first clipboard tool that succeeds.

    The tools are tried in order: macOS ``pbpaste``, Wayland ``wl-paste``,
    X11 ``xclip`` / ``xsel``, and Windows PowerShell ``Get-Clipboard``.

    Raises:
        ValueError: no known tool is installed, or every installed one failed.
    """
    import shutil
    import subprocess

    commands = (
        ["pbpaste"],
        ["wl-paste", "--no-newline"],
        ["xclip", "-selection", "clipboard", "-o"],
        ["xsel", "--clipboard", "--output"],
        ["powershell", "-NoProfile", "-Command", "Get-Clipboard"],
    )
    for cmd in commands:
        if shutil.which(cmd[0]) is None:
            continue
        try:
            # errors="replace" mirrors --file handling: undecodable bytes
            # become U+FFFD instead of raising UnicodeDecodeError.
            return subprocess.check_output(cmd, text=True, errors="replace")
        except (OSError, subprocess.CalledProcessError):
            # e.g. wl-paste installed but not running under Wayland: try the next tool.
            continue
    tried = ", ".join(cmd[0] for cmd in commands)
    raise ValueError(f"Could not read the clipboard (tried: {tried}).")


def read_text(args) -> str:
    """Return the text to classify from the first input source provided.

    Sources are checked in order: ``--text``, ``--file``, ``--clip``, then
    piped (non-TTY) stdin.

    Args:
        args: parsed CLI namespace with ``text``, ``file`` and ``clip``
            attributes (as produced by ``build_parser``).

    Returns:
        The input text. It may be empty when ``--text ""`` is given.

    Raises:
        ValueError: no input source was provided, or ``--clip`` was requested
            but no clipboard tool could be run.
        OSError: the ``--file`` path could not be opened.
    """
    # `is not None` so an explicit `--text ""` is honoured rather than
    # silently falling through to stdin.
    if args.text is not None:
        return args.text
    if args.file:
        with open(args.file, "r", encoding="utf-8", errors="replace") as f:
            return f.read()
    if args.clip:
        return _read_clipboard()
    if not sys.stdin.isatty():
        return sys.stdin.read()
    raise ValueError("No input provided. Use --text, --file, --clip, or pipe content.")
|
|
|
|
def build_parser():
    """Create the argument parser for the ``picotype`` command.

    Returns:
        An ``argparse.ArgumentParser`` exposing the input-source options
        (``--text``, ``--file``, ``--clip``), model selection (``--tier``,
        ``--model-dir``, ``--checkpoint``) and ``--pretty`` output formatting.
    """
    parser = argparse.ArgumentParser(
        prog="picotype",
        description="Classify content type and risk",
    )

    # Where the input comes from.
    parser.add_argument("--text", "-t", help="Text string to classify")
    parser.add_argument("--file", "-f", help="File path to classify")
    parser.add_argument(
        "--clip", "-c",
        action="store_true",
        help="Classify clipboard content",
    )

    # Which model to run.
    parser.add_argument(
        "--tier",
        default="base",
        choices=["tiny", "small", "base", "pro"],
        help="Model tier",
    )
    parser.add_argument(
        "--model-dir",
        default="checkpoints",
        help="Directory with ONNX models",
    )
    parser.add_argument(
        "--checkpoint",
        help="PyTorch checkpoint (fallback if no ONNX)",
    )

    # How to print the result.
    parser.add_argument(
        "--pretty", "-p",
        action="store_true",
        help="Pretty-print JSON output",
    )
    return parser
|
|
|
|
def main():
    """Entry point for the ``picotype`` command.

    Reads the input (text, file, clipboard or piped stdin). If
    ``<model-dir>/picotype_<tier>.onnx`` exists, the ONNX model is run;
    otherwise the ``--checkpoint`` PyTorch model is used. The predictions
    are written to stdout as one JSON object, with ``text_length`` and
    ``tier`` added.

    Exit status:
        0 on success, and on empty/whitespace-only input, which is reported
        on stdout as ``{"error": "empty input"}``.
        1 when no input is available, the input cannot be read, or no model
        is found.
    """
    import subprocess

    args = build_parser().parse_args()
    try:
        text = read_text(args)
    except (ValueError, OSError, subprocess.SubprocessError) as e:
        # A missing/unreadable --file (OSError) or a failing clipboard tool
        # gets the same one-line error as "no input", not a traceback.
        print(e, file=sys.stderr)
        sys.exit(1)

    if not text.strip():
        print('{"error": "empty input"}')
        sys.exit(0)

    # The exported ONNX model takes precedence; --checkpoint is only a fallback.
    onnx_path = os.path.join(args.model_dir, f"picotype_{args.tier}.onnx")
    if os.path.exists(onnx_path):
        session = load_onnx_model(args.tier, args.model_dir)
        result = run_onnx(session, text)
    elif args.checkpoint:
        model = load_torch_model(args.tier, args.checkpoint)
        result = run_torch(model, args.tier, text)
    else:
        print(f"ONNX model not found at {onnx_path}. Use --checkpoint to use PyTorch.", file=sys.stderr)
        sys.exit(1)

    # Character count of the whole input, even though the runners above only
    # feed the first 1024 UTF-8 bytes (their default max_bytes) to the model.
    result["text_length"] = len(text)
    result["tier"] = args.tier
    indent = 2 if args.pretty else None
    json.dump(result, sys.stdout, indent=indent, ensure_ascii=False)
    print()
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed directly, never on import. The relative
    # imports at the top mean this needs `python -m <package>.<module>`.
    main()
|
|