#!/usr/bin/env python3
"""Run example inference for rubai-corrector-base."""

from __future__ import annotations

import argparse
import json
from pathlib import Path

import torch

# NOTE: ``transformers`` is imported lazily inside ``load_model`` so that
# ``--help`` and argument errors respond without importing it.

EXAMPLES = [
    {
        "category": "abbreviation",
        "input": "telefon rqami qaysi",
        "expected": "Telefon raqami qaysi",
    },
    {
        "category": "apostrophe",
        "input": "men ozim kordim",
        "expected": "Men o'zim ko'rdim",
    },
    {
        "category": "apostrophe",
        "input": "togri yoldan boring",
        "expected": "To'g'ri yo'ldan boring",
    },
    {
        "category": "ocr",
        "input": "rnen universitetda oqiyrnan",
        "expected": "Men universitetda o'qiyman",
    },
    {
        "category": "ocr",
        "input": "bu juda rnuhirn masala",
        "expected": "Bu juda muhim masala",
    },
    {
        "category": "numbers",
        "input": "narxi yigirma besh ming so'm",
        "expected": "Narxi 25 000 so'm",
    },
    {
        "category": "numbers",
        "input": "uchrashuv o'n beshinchi yanvar kuni",
        "expected": "Uchrashuv 15-yanvar kuni",
    },
    {
        "category": "mixed_uz_ru",
        "input": "men segodnya bozorga bordim",
        "expected": "Men сегодня bozorga bordim",
    },
    {
        "category": "mixed_script",
        "input": "privet kak делa",
        "expected": "Привет как дела",
    },
    {
        "category": "uzbek_cleanup",
        "input": "xamma narsa tayyor",
        "expected": "Hamma narsa tayyor",
    },
]


def _positive_int(value: str) -> int:
    """argparse ``type=`` callback that accepts only integers >= 1.

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer or is < 1,
            so argparse reports a clean usage error instead of the failure
            surfacing later inside ``model.generate``.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value!r}"
        ) from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {number}")
    return number


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the inference script."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model-path",
        type=Path,
        default=Path(__file__).resolve().parent,
        help="Path to the packaged model folder.",
    )
    parser.add_argument(
        "--device",
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="Inference device, for example cuda:0 or cpu.",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="Run a single custom input instead of the built-in example suite.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=_positive_int,
        default=256,
        help="Maximum generation length.",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Print results as JSON.",
    )
    return parser.parse_args()


def load_model(model_path: Path, device: str):
    """Load the tokenizer and seq2seq model from *model_path* onto *device*.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is moved to *device* and
        switched to eval mode.
    """
    # Imported here rather than at module level so that argument parsing
    # (including --help) does not pay the cost of importing transformers.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    model.to(device)
    model.eval()
    return tokenizer, model


def predict(texts: list[str], tokenizer, model, device: str, max_new_tokens: int) -> list[str]:
    """Return one decoded model output per entry of *texts*, in input order.

    Each input is prefixed with ``"correct: "`` (the prompt format this script
    sends to the model). The batch is padded to a common length and moved to
    *device*, and generation runs under ``torch.inference_mode``.

    Returns:
        The decoded outputs with special tokens skipped. An empty *texts*
        yields ``[]`` without touching the tokenizer or model.
    """
    if not texts:
        # Avoid handing an empty batch to the tokenizer / generate().
        return []
    prompts = [f"correct: {text}" for text in texts]
    encoded = tokenizer(prompts, return_tensors="pt", padding=True)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.inference_mode():
        output_ids = model.generate(**encoded, max_new_tokens=max_new_tokens)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)


def _run_single(args: argparse.Namespace, tokenizer, model) -> None:
    """Correct ``args.text`` and print it as JSON or as labelled lines."""
    prediction = predict([args.text], tokenizer, model, args.device, args.max_new_tokens)[0]
    if args.json:
        print(
            json.dumps(
                {"input": args.text, "prediction": prediction},
                ensure_ascii=False,
                indent=2,
            )
        )
        return
    print(f"Input: {args.text}")
    print(f"Prediction: {prediction}")


def _run_suite(args: argparse.Namespace, tokenizer, model) -> None:
    """Run every entry of EXAMPLES in one batch and report exact-match results."""
    predictions = predict(
        [example["input"] for example in EXAMPLES],
        tokenizer,
        model,
        args.device,
        args.max_new_tokens,
    )
    results = [
        {
            "category": example["category"],
            "input": example["input"],
            "expected": example["expected"],
            "prediction": prediction,
            "exact_match": prediction == example["expected"],
        }
        for example, prediction in zip(EXAMPLES, predictions)
    ]

    if args.json:
        print(json.dumps(results, ensure_ascii=False, indent=2))
        return

    print(f"Model: {args.model_path}")
    print(f"Device: {args.device}")
    print()
    for row in results:
        print(f"[{row['category']}]")
        print(f"Input: {row['input']}")
        print(f"Expected: {row['expected']}")
        print(f"Prediction: {row['prediction']}")
        print(f"Exact: {row['exact_match']}")
        print()
    # Summary line; JSON output above is unchanged.
    matched = sum(row["exact_match"] for row in results)
    print(f"Exact matches: {matched}/{len(results)}")


def main() -> int:
    """Parse arguments, load the model, then run single-text or suite mode.

    Returns:
        Process exit status (always 0 on success).
    """
    args = parse_args()
    tokenizer, model = load_model(args.model_path, args.device)
    if args.text is not None:
        _run_single(args, tokenizer, model)
    else:
        _run_suite(args, tokenizer, model)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())