#!/usr/bin/env python3
"""Run curated transcript-normalization examples for rubai-corrector-transcript-uz."""

from __future__ import annotations

import argparse
import json
from pathlib import Path

import torch

# Built-in example suite. Every entry carries the raw "input" handed to the
# model, the "expected" normalized text it is compared against (exact string
# equality), and a "category" label used only for reporting.
EXAMPLES = [
    {
        "category": "abbreviation_shorthand",
        "input": "tlefon rqami",
        "expected": "Telefon raqami",
    },
    {
        "category": "abbreviation_shorthand",
        "input": "telefon rqami qaysi",
        "expected": "Telefon raqami qaysi",
    },
    {
        "category": "apostrophe",
        "input": "ozbekiston gozal mamlakat bolgan",
        "expected": "O'zbekiston go'zal mamlakat bo'lgan",
    },
    {
        "category": "apostrophe",
        "input": "men ozim kordim",
        "expected": "Men o'zim ko'rdim.",
    },
    {
        "category": "ocr",
        "input": "0zbekiston Respub1ikasi",
        "expected": "O'zbekiston Respublikasi",
    },
    {
        "category": "ocr",
        "input": "5alom dostlar",
        "expected": "Salom do'stlar",
    },
    {
        "category": "numbers",
        "input": "uchrashuv o'n beshinchi yanvar kuni",
        "expected": "Uchrashuv 15-yanvar kuni",
    },
    {
        "category": "numbers",
        "input": "narxi yigirma besh ming so'm",
        "expected": "Narxi 25 000 so'm",
    },
    {
        "category": "mixed_uz_ru",
        "input": "bugun yaxshi kun. segodnya xoroshiy den.",
        "expected": "Bugun yaxshi kun. Сегодня хороший день.",
    },
    {
        "category": "mixed_uz_ru",
        "input": "men bozorga bordim. tam ya kupil xleb.",
        "expected": "Men bozorga bordim. Там я купил хлеб.",
    },
    {
        "category": "russian_only",
        "input": "segodnya xoroshaya pogoda",
        "expected": "Сегодня хорошая погода",
    },
    {
        "category": "russian_only",
        "input": "privet kak dela",
        "expected": "Привет как дела",
    },
    {
        "category": "mixed_script",
        "input": "privet kak делa",
        "expected": "Привет как дела",
    },
    {
        "category": "mixed_script",
        "input": "zaklad bersa keyin gaplashamiz",
        "expected": "Заклад bersa keyin gaplashamiz",
    },
    {
        "category": "display_cleanup",
        "input": "mustahkamlik sinovida spark boshqa avtomobillarni ortda qoldirdi.",
        "expected": "Mustahkamlik sinovida Spark boshqa avtomobillarni ortda qoldirdi.",
    },
    {
        "category": "display_cleanup",
        "input": "kadrlarda kranning mashina old oynasi ustiga qulaganligini ko'rish mumkin",
        "expected": "Kadrlarda kranning mashina old oynasi ustiga qulaganligini ko'rish mumkin.",
    },
]


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        Namespace with ``model_path``, ``device``, ``text``,
        ``max_new_tokens`` and ``json`` attributes.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model-path",
        type=Path,
        # By default the model files are looked up next to this script.
        default=Path(__file__).resolve().parent,
        help="Path to the packaged model folder.",
    )
    parser.add_argument(
        "--device",
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="Inference device, for example cuda:0 or cpu.",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="Run a single custom input instead of the built-in example suite.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=256,
        help="Maximum generation length.",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Print results as JSON.",
    )
    return parser.parse_args(argv)


def load_model(model_path: Path, device: str):
    """Load the tokenizer and seq2seq model from *model_path*.

    The model is moved to *device* and switched to eval mode.

    Args:
        model_path: Folder passed to ``from_pretrained`` for both the
            tokenizer and the model.
        device: Device string handed to ``model.to``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Imported here because this is the only place transformers is used;
    # argument parsing and the example table stay usable without it.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    model.to(device)
    model.eval()
    return tokenizer, model


def predict(texts: list[str], tokenizer, model, device: str, max_new_tokens: int) -> list[str]:
    """Generate one normalized string per entry of *texts*.

    Every text is prefixed with ``"correct: "``, the prompts are tokenized as
    a single padded batch, and the generated ids are decoded with special
    tokens stripped.

    Args:
        texts: Raw input strings.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Model returned by ``load_model``.
        device: Device the input tensors are moved to.
        max_new_tokens: Forwarded to ``model.generate``.

    Returns:
        Decoded predictions, in the same order as *texts*. An empty *texts*
        yields an empty list without touching the tokenizer or the model.
    """
    if not texts:
        # Nothing to tokenize; avoid handing an empty batch to the tokenizer.
        return []

    prompts = [f"correct: {text}" for text in texts]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)


def main(argv: list[str] | None = None) -> int:
    """Run either a single ``--text`` input or the whole example suite.

    Args:
        argv: Argument strings forwarded to ``parse_args``; ``None`` means
            the process command line.

    Returns:
        Always ``0``; mismatching predictions are reported in the output,
        not through the exit status.
    """
    args = parse_args(argv)
    tokenizer, model = load_model(args.model_path, args.device)

    # Single custom input: report just that prediction and stop.
    if args.text is not None:
        prediction = predict(
            [args.text], tokenizer, model, args.device, args.max_new_tokens
        )[0]
        if args.json:
            print(
                json.dumps(
                    {"input": args.text, "prediction": prediction},
                    ensure_ascii=False,
                    indent=2,
                )
            )
        else:
            print(f"Input: {args.text}")
            print(f"Prediction: {prediction}")
        return 0

    # Built-in suite: all examples are sent through the model as one batch.
    predictions = predict(
        [example["input"] for example in EXAMPLES],
        tokenizer,
        model,
        args.device,
        args.max_new_tokens,
    )
    results = [
        {
            "category": example["category"],
            "input": example["input"],
            "expected": example["expected"],
            "prediction": prediction,
            "exact_match": prediction == example["expected"],
        }
        for example, prediction in zip(EXAMPLES, predictions)
    ]

    if args.json:
        print(json.dumps(results, ensure_ascii=False, indent=2))
        return 0

    print(f"Model: {args.model_path}")
    print(f"Device: {args.device}")
    print()
    for row in results:
        print(f"[{row['category']}]")
        print(f"Input: {row['input']}")
        print(f"Expected: {row['expected']}")
        print(f"Prediction: {row['prediction']}")
        print(f"Exact: {row['exact_match']}")
        print()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())