File size: 1,881 Bytes
5ff308a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""
T5-Base Finetuned CEFR Level Prediction Model
by EmreKalkan
"""
import argparse
import torch
from transformers import T5TokenizerFast, T5ForConditionalGeneration

MODEL_DIR = "."                      # Default model/tokenizer location: the current directory (presumably the repo root).
TASK_PREFIX = "classify cefr: "      # Do NOT change: per the author this is a training constant. It is prepended to every input sentence.
MAX_LEN = 96                         # Tokenizer max length; longer inputs are truncated when encoding.
LEVELS = ["a1", "a2", "b1", "b2", "c1"]  # CEFR labels; not referenced in the visible code -- presumably the label set the model emits (TODO confirm).


class CefrClassifier:
    """Text-to-text CEFR classifier backed by a T5 checkpoint.

    The tokenizer and model are loaded from ``model_dir``. The model is moved
    to CUDA when ``torch.cuda.is_available()`` and to CPU otherwise, and is
    switched to eval mode.
    """

    def __init__(self, model_dir: str = MODEL_DIR) -> None:
        """Load the tokenizer and the model weights from ``model_dir``."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = T5TokenizerFast.from_pretrained(model_dir, model_max_length=MAX_LEN)
        self.model = T5ForConditionalGeneration.from_pretrained(model_dir).to(self.device).eval()

    @torch.no_grad()  # inference only: no gradients are tracked
    def predict(self, sentences):
        """Generate a label for one sentence or for a batch of sentences.

        Args:
            sentences: A single ``str`` or an iterable of ``str``.

        Returns:
            For a single ``str``: the decoded model output, stripped and
            lower-cased. For an iterable: a list of such strings in input
            order. An empty iterable yields ``[]`` without calling the model.
        """
        single = isinstance(sentences, str)
        batch = [sentences] if single else list(sentences)
        if not batch:
            # Nothing to classify: skip the tokenizer and generate() entirely
            # rather than feeding them an empty batch.
            return []

        # Every input carries the task prefix (a training-time constant).
        prompts = [TASK_PREFIX + s for s in batch]
        enc = self.tok(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LEN,
        ).to(self.device)
        # Greedy decoding with a short output budget.
        gen = self.model.generate(**enc, max_length=8, num_beams=1)
        decoded = self.tok.batch_decode(gen, skip_special_tokens=True)
        out = [t.strip().lower() for t in decoded]
        return out[0] if single else out


def main():
    """Command-line entry point.

    With ``--text`` the given sentence is classified once, the result is
    printed and the function returns. Without it, sentences are read
    interactively from stdin until ``q``/``quit``/``exit``, EOF or Ctrl-C.
    ``--model_dir`` selects the directory the model is loaded from.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--text", default=None)
    ap.add_argument("--model_dir", default=MODEL_DIR)
    args = ap.parse_args()

    clf = CefrClassifier(args.model_dir)
    # argparse yields a str here (or None when the flag is omitted), so the
    # previous `== 1` comparison was never true and this one-shot branch was
    # unreachable.
    if args.text is not None:
        print(f"{args.text}\n  -> CEFR: {clf.predict(args.text).upper()}")
        return

    print("CEFR Prediction (for quit: q)\n")
    while True:
        try:
            t = input("Sentence> ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C ends the session quietly.
            break
        if t.lower() in {"q", "quit", "exit"}:
            break
        if t:  # ignore blank lines
            print(f"  -> CEFR: {clf.predict(t).upper()}\n")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()