"""T5-Base fine-tuned CEFR level prediction model by EmreKalkan.

Run as a script either to classify a single sentence passed via ``--text``
or, without ``--text``, to start an interactive prompt that classifies one
sentence per line (enter ``q``, ``quit`` or ``exit`` to stop).
"""

import argparse

import torch
from transformers import T5TokenizerFast, T5ForConditionalGeneration

MODEL_DIR = "."  # Repo root: default directory the tokenizer and model are loaded from.
# Do not change: this prefix is a training-time constant the model was fine-tuned with.
TASK_PREFIX = "classify cefr: "
MAX_LEN = 96  # Maximum tokenized input length; longer inputs are truncated.
# Known CEFR labels. NOTE: not currently used to validate the model's output.
LEVELS = ["a1", "a2", "b1", "b2", "c1"]


class CefrClassifier:
    """Predict CEFR level labels for sentences with a fine-tuned T5 model."""

    def __init__(self, model_dir=MODEL_DIR):
        """Load the tokenizer and model from *model_dir*.

        The model is moved to CUDA when available (CPU otherwise) and put in
        eval mode.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = T5TokenizerFast.from_pretrained(model_dir, model_max_length=MAX_LEN)
        self.model = (
            T5ForConditionalGeneration.from_pretrained(model_dir)
            .to(self.device)
            .eval()
        )

    @torch.no_grad()  # Inference only: no gradients are needed.
    def predict(self, sentences):
        """Predict the CEFR level of one sentence or a batch of sentences.

        Args:
            sentences: A single sentence (``str``) or an iterable of sentences.

        Returns:
            For a ``str`` input, the decoded label as a stripped, lower-cased
            string. Otherwise, a list of such labels in input order (an empty
            list for empty input). Labels are whatever the model generates;
            they are not checked against ``LEVELS``.
        """
        single = isinstance(sentences, str)
        if single:
            sentences = [sentences]
        prompts = [TASK_PREFIX + s for s in sentences]
        if not prompts:
            # Nothing to classify; skip tokenization and generation entirely.
            return []
        enc = self.tok(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LEN,
        ).to(self.device)
        # Labels are short strings, so generation is capped at a few tokens (greedy).
        gen = self.model.generate(**enc, max_length=8, num_beams=1)
        out = [t.strip().lower() for t in self.tok.batch_decode(gen, skip_special_tokens=True)]
        return out[0] if single else out


def main():
    """Command-line entry point: one-shot with ``--text``, interactive otherwise."""
    ap = argparse.ArgumentParser(description="Predict the CEFR level of English sentences.")
    ap.add_argument("--text", default=None, help="Sentence to classify; omit for interactive mode.")
    ap.add_argument("--model_dir", default=MODEL_DIR, help="Directory containing the model files.")
    args = ap.parse_args()

    clf = CefrClassifier(args.model_dir)

    # Bug fix: the original compared the string with the integer 1, which is
    # never true, so --text was silently ignored.
    if args.text is not None:
        print(f"{args.text}\n -> CEFR: {clf.predict(args.text).upper()}")
        return

    print("CEFR Prediction (for quit: q)\n")
    while True:
        try:
            t = input("Sentence> ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if t.lower() in {"q", "quit", "exit"}:
            break
        if t:
            print(f" -> CEFR: {clf.predict(t).upper()}\n")


if __name__ == "__main__":
    main()