File size: 1,544 Bytes
0e9c4c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import sqlite3
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
from tqdm import tqdm

# Directory the tokenizer and PEFT adapter are loaded from in annotate_corpus.
CHECKPOINT = "models/emotion_classifier"
# SQLite database whose ``chunks`` table is read and updated by annotate_corpus.
DB_PATH = "data/indexes/metadata.db"
# Predicted class index (0..4) -> safety score written next to the label.
# NOTE(review): what each class index means is not visible in this file —
# confirm against the label set the classifier was trained with.
SAFETY_SCORE_MAP = dict(enumerate((0.0, 0.0, 0.3, 0.7, 1.0)))
# Chunks tokenized, classified and committed together per loop iteration.
BATCH = 128


def annotate_corpus(
    checkpoint=CHECKPOINT,
    db_path=DB_PATH,
    base_model="roberta-base",
    batch_size=BATCH,
    max_length=128,
    device="cpu",
):
    """Predict and store an emotion label for every un-annotated chunk.

    Loads the tokenizer and PEFT adapter from ``checkpoint`` on top of a
    ``base_model`` sequence-classification model with 5 labels. It then
    selects every row of ``chunks`` whose ``emotion_label`` is -1 and
    classifies each row's ``text``. Finally it writes the predicted
    ``emotion_label`` and the corresponding ``safety_score`` (looked up in
    ``SAFETY_SCORE_MAP``) back to that row.

    Updates are committed once per batch. If a run is interrupted, the
    batches already committed are kept, and a re-run only picks up rows
    whose ``emotion_label`` is still -1.

    Args:
        checkpoint: Directory holding the tokenizer and PEFT adapter.
        db_path: Path of the SQLite database containing ``chunks``.
        base_model: Model id/path passed to
            ``AutoModelForSequenceClassification.from_pretrained``; must be
            the model the adapter was trained on.
        batch_size: Number of chunks classified and committed per step.
        max_length: Token limit per text; longer texts are truncated.
        device: Torch device used for inference (e.g. ``"cpu"``, ``"cuda"``).
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    base = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=5)
    model = PeftModel.from_pretrained(base, checkpoint).to(device).eval()

    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT id, text FROM chunks WHERE emotion_label = -1"
        ).fetchall()
        print(f"Chunks to annotate: {len(rows)}")

        for start in tqdm(range(0, len(rows), batch_size), desc="Annotating"):
            batch = rows[start : start + batch_size]
            ids = [row[0] for row in batch]
            texts = [row[1] for row in batch]
            enc = tokenizer(
                texts,
                truncation=True,
                max_length=max_length,
                padding=True,
                return_tensors="pt",
            ).to(device)
            with torch.no_grad():
                logits = model(**enc).logits
            labels = logits.argmax(-1).tolist()
            # One executemany per batch instead of a Python-level execute loop.
            conn.executemany(
                "UPDATE chunks SET emotion_label=?, safety_score=? WHERE id=?",
                [(lbl, SAFETY_SCORE_MAP[lbl], rid) for rid, lbl in zip(ids, labels)],
            )
            # Commit per batch so completed work survives a later failure.
            conn.commit()
    finally:
        # Release the connection even if a query or inference step raises;
        # close() does not commit, so a partially written batch is dropped.
        conn.close()
    print("Annotation complete.")


if __name__ == "__main__":
    # Script entry point: annotate using the module-level default paths.
    annotate_corpus(checkpoint=CHECKPOINT, db_path=DB_PATH)