COPAL / predict.py
balastml's picture
Upload 11 files
5ff308a verified
Raw
History Blame Contribute Delete
1.88 kB
"""
T5-Base Finetuned CEFR Level Prediction Model
by EmreKalkan
"""
import argparse
import torch
from transformers import T5TokenizerFast, T5ForConditionalGeneration
# Default location of the tokenizer/model files passed to from_pretrained.
# "." is the current directory (presumably the repository root, where the
# checkpoint sits next to this script -- confirm).
MODEL_DIR = "."
# Text prepended to every sentence before tokenization. This is a training
# constant: do not change it.
TASK_PREFIX = "classify cefr: "
# Token limit used both as the tokenizer's model_max_length and as the
# truncation length when encoding inputs.
MAX_LEN = 96
# CEFR labels, lowest to highest. Not referenced by the visible code;
# presumably the label set the model was trained to emit -- TODO confirm.
LEVELS = ["a1", "a2", "b1", "b2", "c1"]
class CefrClassifier:
    """Label sentences with a CEFR level using a fine-tuned T5 checkpoint.

    The tokenizer and model are loaded once in the constructor and reused by
    every ``predict`` call.
    """

    def __init__(self, model_dir=MODEL_DIR):
        """Load the tokenizer and the model from ``model_dir``.

        Args:
            model_dir: Location handed to ``from_pretrained`` for both the
                tokenizer and the model. Defaults to ``MODEL_DIR``.
        """
        # Run on the GPU when torch can see one, otherwise on the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = T5TokenizerFast.from_pretrained(model_dir, model_max_length=MAX_LEN)
        # Move the weights to the chosen device and put the module in
        # evaluation mode, since this class is only used for inference.
        self.model = T5ForConditionalGeneration.from_pretrained(model_dir).to(self.device).eval()

    @torch.no_grad()  # inference only: skip gradient tracking
    def predict(self, sentences):
        """Predict a label for one sentence or for a batch of sentences.

        Args:
            sentences: Either a single ``str`` or an iterable of ``str``.

        Returns:
            For a single ``str`` input, one label string. For an iterable,
            a list with one label string per input sentence, in order; an
            empty iterable gives an empty list. Labels are the decoded model
            output, stripped and lower-cased; they are not validated against
            ``LEVELS``.
        """
        single = isinstance(sentences, str)
        # Materialise once so generators work and emptiness can be checked.
        batch = [sentences] if single else list(sentences)
        if not batch:
            # Nothing to classify; the tokenizer cannot encode an empty
            # batch, so answer directly instead of failing inside it.
            return []

        # Every input must carry the task prefix the model was trained with.
        prompts = [TASK_PREFIX + s for s in batch]
        enc = self.tok(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LEN,
        ).to(self.device)
        # The output is a short label, so generation is capped at 8 tokens
        # and uses a single beam.
        gen = self.model.generate(**enc, max_length=8, num_beams=1)
        decoded = self.tok.batch_decode(gen, skip_special_tokens=True)
        out = [t.strip().lower() for t in decoded]
        return out[0] if single else out
def main():
    """Command-line entry point.

    With ``--text``, classify that one sentence, print the result and exit.
    Without it, start an interactive prompt that classifies each non-empty
    line until the user types ``q``/``quit``/``exit`` or sends EOF/Ctrl-C.
    ``--model_dir`` selects where the model is loaded from.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--text", default=None)
    ap.add_argument("--model_dir", default=MODEL_DIR)
    args = ap.parse_args()
    clf = CefrClassifier(args.model_dir)

    # argparse delivers --text as a str, or None when the option is omitted,
    # so test for None; comparing against a number would never match.
    if args.text is not None:
        print(f"{args.text}\n -> CEFR: {clf.predict(args.text).upper()}")
        return

    print("CEFR Prediction (for quit: q)\n")
    while True:
        try:
            t = input("Sentence> ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C ends the session quietly.
            break
        if t.lower() in {"q", "quit", "exit"}:
            break
        if t:  # ignore blank lines
            print(f" -> CEFR: {clf.predict(t).upper()}\n")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()