GautamKishore's picture
Upload folder using huggingface_hub
c5f9f0d verified
Raw
History Blame Contribute Delete
5.45 kB
"""pico-type CLI: classify content from stdin/file/clipboard."""
from __future__ import annotations
import argparse
import json
import os
import sys
import numpy as np
from .labels import ALL_HEADS, COARSE_LABELS, MODALITY_LABELS, SUBTYPE_LABELS, CODE_LANG_LABELS, TEXT_LANG_LABELS, FILE_MIME_LABELS, RISK_LABELS
# Maps each classification head name to its label table. run_onnx and
# run_torch index these tables with the predicted class index.
LABEL_TABLES = dict(
    coarse=COARSE_LABELS,
    modality=MODALITY_LABELS,
    subtype=SUBTYPE_LABELS,
    code_lang=CODE_LANG_LABELS,
    text_lang=TEXT_LANG_LABELS,
    file_mime=FILE_MIME_LABELS,
    risk=RISK_LABELS,
)
def load_onnx_model(tier: str = "base", model_dir: str = "checkpoints"):
    """Open an ONNX Runtime inference session for one model tier.

    Args:
        tier: Tier name; the model file is ``picotype_<tier>.onnx``.
        model_dir: Directory in which the model file is looked up.

    Returns:
        An ``onnxruntime.InferenceSession`` created from the model file.

    Raises:
        FileNotFoundError: If the model file does not exist.
    """
    import onnxruntime as ort

    model_path = os.path.join(model_dir, f"picotype_{tier}.onnx")
    if os.path.exists(model_path):
        return ort.InferenceSession(model_path)
    raise FileNotFoundError(f"ONNX model not found: {model_path}")
def load_torch_model(tier: str = "base", checkpoint: str = ""):
    """Build a PicoType model and restore its weights from a checkpoint.

    Args:
        tier: Returned unchanged as the second element of the result; it
            does not influence how the model is constructed here.
        checkpoint: Path passed to ``torch.load``.

    Returns:
        A ``(model, tier)`` tuple, with the model switched to eval mode.
    """
    import torch
    from .arch import PicoType, PicoTypeConfig

    net = PicoType(PicoTypeConfig(max_bytes=1024))
    # NOTE(review): torch.load unpickles the file unless weights-only loading
    # is in effect, so only trusted checkpoints should be passed here.
    loaded = torch.load(checkpoint, map_location="cpu")
    # Accept either a wrapper dict holding "model_state_dict" or a bare
    # state dict.
    weights = loaded.get("model_state_dict", loaded)
    net.load_state_dict(weights)
    net.eval()
    return net, tier
def run_onnx(session, text: str, max_bytes: int = 1024) -> dict:
    """Classify *text* with an ONNX Runtime session.

    The text is UTF-8 encoded, truncated to ``max_bytes`` bytes and
    right-padded with zeros to exactly ``max_bytes`` entries. A boolean mask
    marks which positions hold real bytes.

    Args:
        session: Object whose ``run(None, feed)`` returns one logits array
            per head. Outputs are paired positionally with ``ALL_HEADS``.
        text: Content to classify.
        max_bytes: Fixed input length fed to the model.

    Returns:
        A dict keyed by head name. The ``"risk"`` entry maps every risk label
        to its probability; every other entry holds the top ``"label"``, its
        ``"confidence"`` and its ``"index"``.
    """
    # Slicing already caps the length at max_bytes, so no second truncation
    # is needed afterwards.
    raw = text.encode("utf-8")[:max_bytes]
    used = len(raw)

    input_ids = np.zeros(max_bytes, dtype=np.int64)
    input_ids[:used] = np.frombuffer(raw, dtype=np.uint8)
    attention_mask = np.zeros(max_bytes, dtype=np.bool_)
    attention_mask[:used] = True

    # Add a leading batch axis of size 1 to both inputs.
    feed = {
        "input_ids": input_ids[None, :],
        "attention_mask": attention_mask[None, :],
    }
    outs = session.run(None, feed)

    result = {}
    for name, logits in zip(ALL_HEADS, outs):
        probs = _softmax(logits[0])
        labels = LABEL_TABLES[name]
        if name == "risk":
            # Risk reports the full distribution rather than a single winner.
            result[name] = {labels[i]: float(p) for i, p in enumerate(probs)}
        else:
            idx = int(np.argmax(probs))
            result[name] = {
                "label": labels[idx],
                "confidence": float(probs[idx]),
                "index": idx,
            }
    return result
def _softmax(x):
e = np.exp(x - np.max(x))
return e / e.sum()
def run_torch(model, tier: str, text: str, max_bytes: int = 1024) -> dict:
    """Classify *text* with a PyTorch model.

    The text is UTF-8 encoded and truncated to ``max_bytes`` bytes. The input
    is not padded and the mask is all ones.

    Args:
        model: Callable model, or a tuple whose first element is the model
            (the shape returned by ``load_torch_model``).
        tier: Key used to select logits when a head's output is a dict.
        text: Content to classify.
        max_bytes: Maximum number of bytes fed to the model.

    Returns:
        A dict keyed by head name. The ``"risk"`` entry maps every risk label
        to its probability; every other entry holds the top ``"label"``, its
        ``"confidence"`` and its ``"index"``.
    """
    import torch

    if isinstance(model, tuple):
        model = model[0]

    payload = text.encode("utf-8")[:max_bytes]
    ids = torch.tensor([list(payload)], dtype=torch.long)
    mask = torch.ones(1, ids.shape[1], dtype=torch.bool)
    with torch.no_grad():
        head_outputs = model(ids, mask)

    out = {}
    for head in ALL_HEADS:
        selected = head_outputs[head]
        # A head's output may be keyed by tier, possibly nested one level
        # deeper; unwrap at most twice.
        if isinstance(selected, dict):
            selected = selected[tier]
        if isinstance(selected, dict):
            selected = selected[tier]

        probs = torch.softmax(selected[0], dim=-1)
        labels = LABEL_TABLES[head]
        if head != "risk":
            idx = int(torch.argmax(probs).item())
            out[head] = {
                "label": labels[idx],
                "confidence": float(probs[idx]),
                "index": idx,
            }
        else:
            # Risk reports the full distribution rather than a single winner.
            out[head] = {labels[i]: p for i, p in enumerate(probs.tolist())}
    return out
def read_text(args) -> str:
    """Return the text to classify from the first available input source.

    Sources are tried in this order: ``args.text``, ``args.file``,
    ``args.clip`` (clipboard), then piped standard input.

    Args:
        args: Parsed arguments exposing ``text``, ``file`` and ``clip``.

    Returns:
        The text read from the selected source.

    Raises:
        ValueError: If no source was given and stdin is a terminal.
    """
    # Compare against None so that an explicit empty --text value is honoured
    # instead of silently falling through to the other sources.
    if args.text is not None:
        return args.text
    if args.file:
        # Undecodable bytes are replaced rather than raising, so arbitrary
        # files can still be classified.
        with open(args.file, "r", encoding="utf-8", errors="replace") as f:
            return f.read()
    if args.clip:
        import subprocess

        # pbpaste is the macOS clipboard reader; on systems without it this
        # call raises FileNotFoundError.
        return subprocess.check_output(["pbpaste"], text=True)
    if not sys.stdin.isatty():
        return sys.stdin.read()
    raise ValueError("No input provided. Use --text, --file, --clip, or pipe content.")
def build_parser():
    """Create the argument parser for the ``picotype`` command line.

    Returns:
        An ``argparse.ArgumentParser`` with the input-source options
        (``--text``, ``--file``, ``--clip``), the model-selection options
        (``--tier``, ``--model-dir``, ``--checkpoint``) and ``--pretty``.
    """
    parser = argparse.ArgumentParser(
        prog="picotype",
        description="Classify content type and risk",
    )
    # Declared as data and registered in order, so the help output lists the
    # options in this sequence.
    option_specs = (
        (("--text", "-t"), {"help": "Text string to classify"}),
        (("--file", "-f"), {"help": "File path to classify"}),
        (
            ("--clip", "-c"),
            {"action": "store_true", "help": "Classify clipboard content"},
        ),
        (
            ("--tier",),
            {
                "default": "base",
                "choices": ["tiny", "small", "base", "pro"],
                "help": "Model tier",
            },
        ),
        (
            ("--model-dir",),
            {"default": "checkpoints", "help": "Directory with ONNX models"},
        ),
        (
            ("--checkpoint",),
            {"help": "PyTorch checkpoint (fallback if no ONNX)"},
        ),
        (
            ("--pretty", "-p"),
            {"action": "store_true", "help": "Pretty-print JSON output"},
        ),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser
def main():
    """Run the CLI: read the input, classify it, and print JSON to stdout.

    The ONNX model for the selected tier is used when its file exists;
    otherwise the PyTorch checkpoint given by ``--checkpoint`` is used.

    Exits with status 1 when the input cannot be read or when no model is
    available. Empty or whitespace-only input prints an error object and
    exits with status 0.
    """
    import subprocess

    args = build_parser().parse_args()
    try:
        text = read_text(args)
    except (ValueError, OSError, subprocess.CalledProcessError) as e:
        # OSError covers an unreadable or missing --file and a missing
        # clipboard tool; CalledProcessError covers the clipboard command
        # failing. Report them like any other input error, not as a traceback.
        print(e, file=sys.stderr)
        sys.exit(1)

    if not text.strip():
        print('{"error": "empty input"}')
        sys.exit(0)

    onnx_path = os.path.join(args.model_dir, f"picotype_{args.tier}.onnx")
    if os.path.exists(onnx_path):
        session = load_onnx_model(args.tier, args.model_dir)
        result = run_onnx(session, text)
    elif args.checkpoint:
        # load_torch_model returns a (model, tier) tuple; run_torch unwraps it.
        model = load_torch_model(args.tier, args.checkpoint)
        result = run_torch(model, args.tier, text)
    else:
        print(f"ONNX model not found at {onnx_path}. Use --checkpoint to use PyTorch.", file=sys.stderr)
        sys.exit(1)

    # text_length counts characters of the input string, not encoded bytes.
    result["text_length"] = len(text)
    result["tier"] = args.tier
    indent = 2 if args.pretty else None
    json.dump(result, sys.stdout, indent=indent, ensure_ascii=False)
    # json.dump writes no trailing newline, so add one.
    print()
# Run the CLI only when this module is executed directly, not when imported.
if __name__ == "__main__":
    main()