{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "b41fd227", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ], "source": [ "!python settings.py" ] }, { "cell_type": "code", "execution_count": null, "id": "030016c2", "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import pandas as pd\n", "from pprint import pprint\n", "from tqdm.autonotebook import tqdm\n", "\n", "from sentence_transformers import SentenceTransformer\n", "\n", "from mteb import MTEB\n", "from mteb.abstasks.TaskMetadata import TaskMetadata\n", "from mteb.abstasks.AbsTaskRetrieval import AbsTaskRetrieval\n", "\n", "from settings import MODEL_NAME, OUTPUT_DIR, DEVICE, BATCH_SIZE\n", "os.environ['WANDB_DISABLED'] = 'true'" ] }, { "cell_type": "code", "execution_count": 4, "id": "dd3f53a3", "metadata": {}, "outputs": [], "source": [ "data = {\n", " 'corpus': pd.read_parquet('data/processed/corpus_data.parquet'),\n", " 'train' : pd.read_parquet('data/processed/train_data.parquet'),\n", " 'test' : pd.read_parquet('data/processed/test_data.parquet')\n", "}\n", "for split in ['train', 'test']:\n", " data[split]['cid'] = data[split]['cid'].apply(lambda x: x.tolist())\n", " data[split]['context_list'] = data[split]['context_list'].apply(lambda x: x.tolist())" ] }, { "cell_type": "code", "execution_count": 5, "id": "41ffd5ce", "metadata": {}, "outputs": [], "source": [ "class BKAILegalDocRetrievalTask(AbsTaskRetrieval):\n", " # Metadata definition used by MTEB benchmark\n", " metadata = TaskMetadata(name='BKAILegalDocRetrieval',\n", " description='',\n", " reference='https://github.com/embeddings-benchmark/mteb/blob/main/docs/adding_a_dataset.md',\n", " type='Retrieval',\n", " category='s2p',\n", " modalities=['text'],\n", " eval_splits=['test'],\n", " eval_langs=['vi'],\n", " main_score='ndcg_at_10',\n", " other_scores=['recall_at_10', 'precision_at_10', 'map'],\n", " dataset={\n", " 'path' : 'data',\n", 
" 'revision': 'd4c5a8ba10ae71224752c727094ac4c46947fa29',\n", " },\n", " date=('2012-01-01', '2020-01-01'),\n", " form='Written',\n", " domains=['Academic', 'Non-fiction'],\n", " task_subtypes=['Scientific Reranking'],\n", " license='cc-by-nc-4.0',\n", " annotations_creators='derived',\n", " dialect=[],\n", " text_creation='found',\n", " bibtex_citation=''\n", " )\n", "\n", " data_loaded = True # Flag\n", "\n", " def __init__(self, **kwargs):\n", " super().__init__(**kwargs)\n", "\n", " self.corpus = {}\n", " self.queries = {}\n", " self.relevant_docs = {}\n", "\n", " shared_corpus = {}\n", " for _, row in data['corpus'].iterrows():\n", " shared_corpus[f\"c{row['cid']}\"] = {\n", " 'text': row['text'],\n", " '_id' : row['cid']\n", " }\n", " \n", " for split in ['train', 'test']:\n", " self.corpus[split] = shared_corpus\n", " self.queries[split] = {}\n", " self.relevant_docs[split] = {}\n", "\n", " for split in ['train', 'test']:\n", " for _, row in data[split].iterrows():\n", " qid, cids = row['qid'], row['cid']\n", " \n", " qid_str = f'q{qid}'\n", " cids_str = [f'c{cid}' for cid in cids]\n", " \n", " self.queries[split][qid_str] = row['question']\n", " \n", " if qid_str not in self.relevant_docs[split]:\n", " self.relevant_docs[split][qid_str] = {}\n", " \n", " for cid_str in cids_str:\n", " self.relevant_docs[split][qid_str][cid_str] = 1\n", " \n", " self.data_loaded = True" ] }, { "cell_type": "code", "execution_count": 6, "id": "8c212fe9", "metadata": {}, "outputs": [], "source": [ "fine_tuned_model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)" ] }, { "cell_type": "code", "execution_count": 7, "id": "aae09322", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The `batch_size` argument is deprecated and will be removed in the next release. Please use `encode_kwargs = {'batch_size': ...}` to set the batch size instead.\n" ] }, { "data": { "text/html": [ "
───────────────────────────────────────────────── Selected tasks ─────────────────────────────────────────────────\n", "\n" ], "text/plain": [ "\u001b[38;5;235m───────────────────────────────────────────────── \u001b[0m\u001b[1mSelected tasks \u001b[0m\u001b[38;5;235m ─────────────────────────────────────────────────\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Retrieval\n",
"\n"
],
"text/plain": [
"\u001b[1mRetrieval\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" - BKAILegalDocRetrieval, s2p\n",
"\n"
],
"text/plain": [
" - BKAILegalDocRetrieval, \u001b[3;38;5;241ms2p\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"\n"
],
"text/plain": [
"\n",
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "53778754caf4456f8e140cfa58b60709",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/233 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f9b27ae885fc46ad83f332f222a76381",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/391 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0b6e38b0d54a4b429db05158604d24a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/391 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "20ec5df7261c43a7921abc968cc5e3a6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/391 [00:02, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5f365f06d3de4becb965adb801aeee60",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/391 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a43b764ac83e43aeb754c1e60771fd5c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/391 [00:02, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ae46c8f76bc64eac8ca475d13f312875",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/91 [00:02, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[TaskResult(task_name=BKAILegalDocRetrieval, scores=...)]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluate the fine-tuned model on the custom retrieval task.\n",
"custom_task = BKAILegalDocRetrievalTask()\n",
"evaluation = MTEB(tasks=[custom_task])\n",
"\n",
"# The top-level `batch_size` argument of `run` is deprecated in MTEB; the batch\n",
"# size is handed to the encoder through `encode_kwargs` instead.\n",
"evaluation.run(fine_tuned_model, encode_kwargs={'batch_size': BATCH_SIZE})"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "004e6930",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Main Evaluation Metrics (Top-K = 10):\n",
"{'evaluation_time (s)': 3061.7869832515717,\n",
" 'main_score': 0.60389,\n",
" 'mrr@10': 0.555102,\n",
" 'precision@10': 0.08587,\n",
" 'recall@10': 0.79407}\n"
]
}
],
"source": [
"# Saved evaluation result for this task (presumably the file written by the MTEB\n",
"# run; verify the folder names if the model ever exposes a name/revision).\n",
"file_path = \"results/no_model_name_available/no_revision_available/BKAILegalDocRetrieval.json\"\n",
"\n",
"with open(file_path, 'r', encoding='utf-8') as f:\n",
"    eval_data = json.load(f)\n",
"\n",
"# Scores of the first entry recorded for the 'test' split.\n",
"scores = eval_data[\"scores\"][\"test\"][0]\n",
"\n",
"# Headline metrics at cutoff 10. 'main_score' is reported from 'ndcg_at_10';\n",
"# `.get` leaves a metric as None when its key is absent from the result file.\n",
"main_metrics = {\n",
"    'main_score'         : scores.get('ndcg_at_10'),\n",
"    'recall@10'          : scores.get('recall_at_10'),\n",
"    'precision@10'       : scores.get('precision_at_10'),\n",
"    'mrr@10'             : scores.get('mrr_at_10'),\n",
"    'evaluation_time (s)': eval_data.get('evaluation_time')\n",
"}\n",
"\n",
"print('Main Evaluation Metrics (Top-K = 10):')\n",
"pprint(main_metrics)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "672ebc32",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Evaluation Scores by K:\n",
"metric map mrr ndcg precision recall\n",
"k \n",
"1 0.4033 0.4242 0.4242 0.4242 0.4033\n",
"3 0.5031 0.5247 0.5394 0.2215 0.6232\n",
"5 0.5230 0.5434 0.5739 0.1512 0.7047\n",
"10 0.5361 0.5551 0.6039 0.0859 0.7941\n",
"20 0.5414 0.5596 0.6216 0.0469 0.8611\n",
"100 0.5442 0.5617 0.6389 0.0104 0.9480\n",
"1000 0.5444 0.5619 0.6444 0.0011 0.9879\n"
]
}
],
"source": [
"# Keep only the per-cutoff metrics (keys shaped like '<metric>_at_<k>') and\n",
"# leave out every key that starts with 'nauc'.\n",
"metrics = {\n",
"    key: value\n",
"    for key, value in scores.items()\n",
"    if not key.startswith('nauc') and '_at_' in key\n",
"}\n",
"\n",
"# One record per (metric, k) pair; the inner single-item loop unpacks the two\n",
"# halves of the key around '_at_'.\n",
"parsed_metrics = [\n",
"    {'metric': metric, 'k': int(at_k), 'score': value}\n",
"    for key, value in metrics.items()\n",
"    for metric, at_k in [key.split('_at_')]\n",
"]\n",
"\n",
"# Rows are cutoffs k (ascending), columns are metric names.\n",
"df_metrics = (\n",
"    pd.DataFrame(parsed_metrics)\n",
"    .pivot(index='k', columns='metric', values='score')\n",
"    .sort_index()\n",
")\n",
"\n",
"print(\"\\nEvaluation Scores by K:\")\n",
"print(df_metrics.round(4))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "legal_doc_retrieval",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}