| |
| """Run example inference for rubai-corrector-base.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import torch |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
|
# Built-in evaluation suite: each entry pairs a raw input sentence with the
# reference correction. Categories exercise abbreviation expansion, Uzbek
# apostrophe restoration (o'/g'), OCR confusions (e.g. "rn" read as "m"),
# spelled-out numbers, mixed Uzbek/Russian vocabulary, mixed Cyrillic/Latin
# script, and general Uzbek spelling cleanup.
EXAMPLES = [
    {
        "category": "abbreviation",
        "input": "telefon rqami qaysi",
        "expected": "Telefon raqami qaysi",
    },
    {
        "category": "apostrophe",
        "input": "men ozim kordim",
        "expected": "Men o'zim ko'rdim",
    },
    {
        "category": "apostrophe",
        "input": "togri yoldan boring",
        "expected": "To'g'ri yo'ldan boring",
    },
    {
        "category": "ocr",
        "input": "rnen universitetda oqiyrnan",
        "expected": "Men universitetda o'qiyman",
    },
    {
        "category": "ocr",
        "input": "bu juda rnuhirn masala",
        "expected": "Bu juda muhim masala",
    },
    {
        "category": "numbers",
        "input": "narxi yigirma besh ming so'm",
        "expected": "Narxi 25 000 so'm",
    },
    {
        "category": "numbers",
        "input": "uchrashuv o'n beshinchi yanvar kuni",
        "expected": "Uchrashuv 15-yanvar kuni",
    },
    {
        "category": "mixed_uz_ru",
        "input": "men segodnya bozorga bordim",
        "expected": "Men сегодня bozorga bordim",
    },
    {
        "category": "mixed_script",
        "input": "privet kak делa",
        "expected": "Привет как дела",
    },
    {
        "category": "uzbek_cleanup",
        "input": "xamma narsa tayyor",
        "expected": "Hamma narsa tayyor",
    },
]
|
|
|
|
| def parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser(description=__doc__) |
| parser.add_argument( |
| "--model-path", |
| type=Path, |
| default=Path(__file__).resolve().parent, |
| help="Path to the packaged model folder.", |
| ) |
| parser.add_argument( |
| "--device", |
| default="cuda:0" if torch.cuda.is_available() else "cpu", |
| help="Inference device, for example cuda:0 or cpu.", |
| ) |
| parser.add_argument( |
| "--text", |
| type=str, |
| default=None, |
| help="Run a single custom input instead of the built-in example suite.", |
| ) |
| parser.add_argument( |
| "--max-new-tokens", |
| type=int, |
| default=256, |
| help="Maximum generation length.", |
| ) |
| parser.add_argument( |
| "--json", |
| action="store_true", |
| help="Print results as JSON.", |
| ) |
| return parser.parse_args() |
|
|
|
|
def load_model(model_path: Path, device: str):
    """Load the packaged tokenizer and seq2seq model from *model_path*.

    The model is moved to *device* and switched to eval mode before being
    returned, so callers can run inference immediately.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    corrector = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
    corrector.eval()
    return tokenizer, corrector
|
|
|
|
def predict(texts: list[str], tokenizer, model, device: str, max_new_tokens: int) -> list[str]:
    """Correct a batch of input texts and return the decoded predictions.

    Each text is wrapped in the model's "correct: " task prefix, the batch is
    tokenized and moved to *device*, and generation runs without gradient
    tracking. Outputs are returned in the same order as *texts*.
    """
    batch = tokenizer(
        [f"correct: {text}" for text in texts],
        return_tensors="pt",
        padding=True,
    )
    batch = {key: value.to(device) for key, value in batch.items()}
    with torch.inference_mode():
        generated = model.generate(**batch, max_new_tokens=max_new_tokens)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
|
|
|
|
def main() -> int:
    """Entry point: run one custom input or the built-in example suite.

    Returns 0 on completion; output goes to stdout as JSON or plain text
    depending on the ``--json`` flag.
    """
    args = parse_args()
    tokenizer, model = load_model(args.model_path, args.device)

    # Single-input mode: correct one sentence supplied via --text and exit.
    if args.text is not None:
        corrected = predict([args.text], tokenizer, model, args.device, args.max_new_tokens)[0]
        if args.json:
            payload = {"input": args.text, "prediction": corrected}
            print(json.dumps(payload, ensure_ascii=False, indent=2))
        else:
            print(f"Input: {args.text}")
            print(f"Prediction: {corrected}")
        return 0

    # Suite mode: run every built-in example in a single batch.
    outputs = predict(
        [example["input"] for example in EXAMPLES],
        tokenizer,
        model,
        args.device,
        args.max_new_tokens,
    )

    results = [
        {
            "category": example["category"],
            "input": example["input"],
            "expected": example["expected"],
            "prediction": output,
            "exact_match": output == example["expected"],
        }
        for example, output in zip(EXAMPLES, outputs)
    ]

    if args.json:
        print(json.dumps(results, ensure_ascii=False, indent=2))
        return 0

    # Human-readable report, one stanza per example.
    print(f"Model: {args.model_path}")
    print(f"Device: {args.device}")
    print()
    for row in results:
        print(f"[{row['category']}]")
        print(f"Input: {row['input']}")
        print(f"Expected: {row['expected']}")
        print(f"Prediction: {row['prediction']}")
        print(f"Exact: {row['exact_match']}")
        print()
    return 0
|
|
|
|
if __name__ == "__main__":
    # SystemExit carries main()'s return value out as the process exit code.
    raise SystemExit(main())
|
|