EmpathRAG / src /models /annotate_corpus.py
Mukul Rayana
fix: scope models/ gitignore to root only, add src/models package, remove .claude
0e9c4c7
import sqlite3
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
from tqdm import tqdm
CHECKPOINT = "models/emotion_classifier"
DB_PATH = "data/indexes/metadata.db"
SAFETY_SCORE_MAP = {0: 0.0, 1: 0.0, 2: 0.3, 3: 0.7, 4: 1.0}
BATCH = 128
def annotate_corpus(checkpoint=CHECKPOINT, db_path=DB_PATH):
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
base = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=5)
model = PeftModel.from_pretrained(base, checkpoint).eval()
conn = sqlite3.connect(db_path)
rows = conn.execute(
"SELECT id, text FROM chunks WHERE emotion_label = -1"
).fetchall()
print(f"Chunks to annotate: {len(rows)}")
for i in tqdm(range(0, len(rows), BATCH), desc="Annotating"):
batch = rows[i : i + BATCH]
ids = [r[0] for r in batch]
texts = [r[1] for r in batch]
enc = tokenizer(
texts, truncation=True, max_length=128, padding=True, return_tensors="pt"
)
with torch.no_grad():
logits = model(**enc).logits
labels = logits.argmax(-1).tolist()
for rid, lbl in zip(ids, labels):
score = SAFETY_SCORE_MAP[lbl]
conn.execute(
"UPDATE chunks SET emotion_label=?, safety_score=? WHERE id=?",
(lbl, score, rid),
)
conn.commit()
conn.close()
print("Annotation complete.")
if __name__ == "__main__":
annotate_corpus()