File size: 1,843 Bytes
34052ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import json
import os
from typing import Dict, Any
from dataclasses import dataclass
from enum import Enum
from datetime import datetime
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from datasets import load_dataset
import traceback
from src.envs import API, OWNER, EVAL_REQUESTS_PATH, EVAL_RESULTS_PATH, RESULTS_REPO



def evaluate_tunisian_corpus_coverage(model, tokenizer, device):
    """Measure the tokenizer's coverage of the Tunisian Dialect Corpus.

    Coverage is the fraction of tokens, over the whole ``train`` split of
    ``arbml/Tunisian_Dialect_Corpus``, that are not the tokenizer's unknown
    token.

    Args:
        model: Unused by this function; kept for interface compatibility.
        tokenizer: Callable tokenizer exposing ``unk_token`` and
            ``convert_ids_to_tokens``; called on batches of the ``Tweet``
            column.
        device: Unused by this function; kept for interface compatibility.

    Returns:
        A one-entry dict mapping ``"arbml/Tunisian_Dialect_Corpus"`` to the
        coverage ratio (``0`` when the dataset yields no tokens).

    Raises:
        Exception: Any error raised while loading or tokenizing the dataset
            is printed together with its traceback and then re-raised.
    """
    try:
        dataset = load_dataset("arbml/Tunisian_Dialect_Corpus", split="train")

        def preprocess(examples):
            # The text lives in the 'Tweet' column. No padding or truncation:
            # every token of every tweet must be counted exactly once.
            # NOTE(review): the tokenizer is called with its default settings
            # otherwise, so any special tokens it adds are presumably counted
            # as covered tokens too -- confirm this is intended.
            return tokenizer(
                examples['Tweet'],
                padding=False,
                truncation=False,
                max_length=None,
            )

        dataset = dataset.map(preprocess, batched=True)

        # Loop-invariant: read the unknown token once rather than per token.
        unk_token = tokenizer.unk_token

        total_tokens = 0
        covered_tokens = 0

        for example in dataset:
            tokens = tokenizer.convert_ids_to_tokens(example['input_ids'])
            total_tokens += len(tokens)
            # Covered = everything that is not the unknown token.
            covered_tokens += len(tokens) - tokens.count(unk_token)

        coverage = covered_tokens / total_tokens if total_tokens > 0 else 0
        print(f"Tunisian Corpus Coverage: {coverage:.2%}")
        return {"arbml/Tunisian_Dialect_Corpus": coverage}
    except Exception as e:
        print(f"Error in Tunisian Corpus evaluation: {str(e)}")
        print(f"Full traceback: {traceback.format_exc()}")
        # Bare raise re-raises the active exception with its traceback intact.
        raise