# NOTE(review): the three lines below are Hugging Face web-UI residue
# (uploader caption, commit message, commit hash) captured alongside the
# file; commented out so the script parses as Python.
# islomov's picture
# Initial private upload
# f361c60 verified
#!/usr/bin/env python3
"""Run curated transcript-normalization examples for rubai-corrector-transcript-uz."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Curated evaluation suite: each entry pairs a noisy input with the
# normalized transcript the model is expected to produce. Categories cover
# abbreviation expansion, Uzbek apostrophe restoration, OCR digit/letter
# confusions, spoken-number formatting, Uzbek/Russian code-switching and
# transliteration, mixed scripts, and display cleanup (casing/punctuation).
EXAMPLES: list[dict[str, str]] = [
    {
        "category": "abbreviation_shorthand",
        "input": "tlefon rqami",
        "expected": "Telefon raqami",
    },
    {
        "category": "abbreviation_shorthand",
        "input": "telefon rqami qaysi",
        "expected": "Telefon raqami qaysi",
    },
    {
        "category": "apostrophe",
        "input": "ozbekiston gozal mamlakat bolgan",
        "expected": "O'zbekiston go'zal mamlakat bo'lgan",
    },
    {
        "category": "apostrophe",
        "input": "men ozim kordim",
        "expected": "Men o'zim ko'rdim.",
    },
    {
        "category": "ocr",
        "input": "0zbekiston Respub1ikasi",
        "expected": "O'zbekiston Respublikasi",
    },
    {
        "category": "ocr",
        "input": "5alom dostlar",
        "expected": "Salom do'stlar",
    },
    {
        "category": "numbers",
        "input": "uchrashuv o'n beshinchi yanvar kuni",
        "expected": "Uchrashuv 15-yanvar kuni",
    },
    {
        "category": "numbers",
        "input": "narxi yigirma besh ming so'm",
        "expected": "Narxi 25 000 so'm",
    },
    {
        "category": "mixed_uz_ru",
        "input": "bugun yaxshi kun. segodnya xoroshiy den.",
        "expected": "Bugun yaxshi kun. Сегодня хороший день.",
    },
    {
        "category": "mixed_uz_ru",
        "input": "men bozorga bordim. tam ya kupil xleb.",
        "expected": "Men bozorga bordim. Там я купил хлеб.",
    },
    {
        "category": "russian_only",
        "input": "segodnya xoroshaya pogoda",
        "expected": "Сегодня хорошая погода",
    },
    {
        "category": "russian_only",
        "input": "privet kak dela",
        "expected": "Привет как дела",
    },
    {
        "category": "mixed_script",
        "input": "privet kak делa",
        "expected": "Привет как дела",
    },
    {
        "category": "mixed_script",
        "input": "zaklad bersa keyin gaplashamiz",
        "expected": "Заклад bersa keyin gaplashamiz",
    },
    {
        "category": "display_cleanup",
        "input": "mustahkamlik sinovida spark boshqa avtomobillarni ortda qoldirdi.",
        "expected": "Mustahkamlik sinovida Spark boshqa avtomobillarni ortda qoldirdi.",
    },
    {
        "category": "display_cleanup",
        "input": "kadrlarda kranning mashina old oynasi ustiga qulaganligini ko'rish mumkin",
        "expected": "Kadrlarda kranning mashina old oynasi ustiga qulaganligini ko'rish mumkin.",
    },
]
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI for the example runner.

    Returns:
        argparse.Namespace with ``model_path``, ``device``, ``text``,
        ``max_new_tokens``, and ``json`` attributes.
    """
    # Prefer the first CUDA device when one is visible, else stay on CPU.
    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model-path",
        type=Path,
        default=Path(__file__).resolve().parent,  # model ships beside this script
        help="Path to the packaged model folder.",
    )
    parser.add_argument(
        "--device",
        default=default_device,
        help="Inference device, for example cuda:0 or cpu.",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="Run a single custom input instead of the built-in example suite.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=256,
        help="Maximum generation length.",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Print results as JSON.",
    )
    return parser.parse_args()
def load_model(model_path: Path, device: str):
    """Load the tokenizer/model pair from *model_path*.

    The model is moved to *device* and switched to evaluation mode before
    being returned alongside its tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # .to() and .eval() both return the model itself, so chaining is safe.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
    return tokenizer, model.eval()
def predict(texts: list[str], tokenizer, model, device: str, max_new_tokens: int) -> list[str]:
    """Run batched correction over *texts* and return the decoded outputs.

    Each input is wrapped with the model's ``correct:`` task prefix before
    tokenization; generation runs under ``torch.inference_mode`` with no
    gradient bookkeeping.
    """
    batch = tokenizer(["correct: " + text for text in texts], return_tensors="pt", padding=True)
    # Move every tensor in the encoded batch onto the target device.
    batch = {key: value.to(device) for key, value in batch.items()}
    with torch.inference_mode():
        generated = model.generate(**batch, max_new_tokens=max_new_tokens)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
def main() -> int:
    """Entry point: score one custom input or the whole curated suite."""
    args = parse_args()
    tokenizer, model = load_model(args.model_path, args.device)

    # Single-input mode short-circuits the built-in example suite.
    if args.text is not None:
        prediction = predict([args.text], tokenizer, model, args.device, args.max_new_tokens)[0]
        if args.json:
            print(json.dumps({"input": args.text, "prediction": prediction}, ensure_ascii=False, indent=2))
        else:
            print(f"Input: {args.text}")
            print(f"Prediction: {prediction}")
        return 0

    predictions = predict(
        [example["input"] for example in EXAMPLES],
        tokenizer,
        model,
        args.device,
        args.max_new_tokens,
    )
    results = [
        {
            "category": example["category"],
            "input": example["input"],
            "expected": example["expected"],
            "prediction": prediction,
            "exact_match": prediction == example["expected"],
        }
        for example, prediction in zip(EXAMPLES, predictions)
    ]

    if args.json:
        print(json.dumps(results, ensure_ascii=False, indent=2))
        return 0

    # Human-readable report: one stanza per example.
    print(f"Model: {args.model_path}")
    print(f"Device: {args.device}")
    print()
    for row in results:
        print(f"[{row['category']}]")
        print(f"Input: {row['input']}")
        print(f"Expected: {row['expected']}")
        print(f"Prediction: {row['prediction']}")
        print(f"Exact: {row['exact_match']}")
        print()
    return 0
# Script entry point: propagate main()'s return code as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())