File size: 2,159 Bytes
f29d474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import json


def evaluate_perplexity(model_path: str, eval_file: str) -> float:
    """Compute perplexity of a causal LM over the raw text in *eval_file*.

    Args:
        model_path: Hugging Face model directory or hub id.
        eval_file: UTF-8 plain-text file used as the evaluation corpus.

    Returns:
        exp(mean cross-entropy loss) over the (possibly truncated) text.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="auto"
    )
    with open(eval_file, 'r', encoding='utf-8') as f:
        text = f.read()
    # Truncate to the model's context window: previously the full corpus was
    # fed unchecked, which fails on any text longer than the model supports.
    # (tokenizer.model_max_length can be a huge sentinel, so prefer the
    # config's max_position_embeddings when available.)
    max_len = getattr(model.config, "max_position_embeddings", None) or tokenizer.model_max_length
    enc = tokenizer(text, return_tensors='pt', truncation=True, max_length=max_len)
    # Move every tensor to the model device once; the original moved
    # input_ids a second time just to build the labels.
    batch = {k: v.to(model.device) for k, v in enc.items()}
    with torch.no_grad():
        loss = model(**batch, labels=batch["input_ids"]).loss
    ppl = torch.exp(loss).item()
    print(f"Perplexity: {ppl:.2f}")
    return ppl


def _qa_accuracy(model_path: str, qa_file: str) -> None:
    """Greedy-decode answers for a JSONL QA set and print substring-match accuracy.

    Each non-blank line of *qa_file* must be a JSON object with "question"
    and "answer" keys. Accuracy is a naive case-insensitive substring match
    of the gold answer against the generated continuation only.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="auto"
    )
    # Some tokenizers define no pad token; fall back to EOS so generate()
    # does not warn (explicit `is not None` check — pad id 0 is valid).
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    correct = 0
    total = 0
    with open(qa_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines in the JSONL file
            item = json.loads(line)
            q = item["question"]
            gt = item["answer"].strip()
            prompt = f"Прашање: {q}\nОдговор:"
            inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens=64,
                    do_sample=False,  # deterministic eval regardless of generation_config
                    pad_token_id=pad_id,
                )
            # Decode ONLY the newly generated tokens. The original decoded the
            # whole sequence (prompt included), so a gold answer appearing in
            # the *question* text counted as correct and inflated accuracy.
            gen_tokens = out[0][inputs["input_ids"].shape[1]:]
            pred = tokenizer.decode(gen_tokens, skip_special_tokens=True)
            # naive check
            correct += 1 if gt.lower() in pred.lower() else 0
            total += 1
    if total:
        acc = correct / total * 100
        # Report the actual file evaluated, not a hardcoded name: QA_EVAL_FILE
        # may point somewhere other than mk_eval.jsonl.
        print(f"QA accuracy on {os.path.basename(qa_file)}: {acc:.1f}% ({correct}/{total})")


if __name__ == "__main__":
    model_path = os.getenv("MODEL_PATH", "./models/mistral-finetuned-mk")
    eval_file = os.getenv("EVAL_FILE", "data/cleaned/mk_combined_data.txt")
    evaluate_perplexity(model_path, eval_file)

    # Simple QA accuracy on small Macedonian eval
    qa_file = os.getenv("QA_EVAL_FILE", "data/eval/mk_eval.jsonl")
    if os.path.exists(qa_file):
        _qa_accuracy(model_path, qa_file)