irinaqqq committed
Commit 87ce102 · 1 Parent(s): 6a02b16
.gitignore CHANGED
@@ -1,2 +1,5 @@
.vscode
venv*
+ .env
+ __pycache__/**
+ *.pyc
site/backend/.gitignore CHANGED
@@ -1,2 +1,5 @@
.vscode
venv*
+ .env
+ __pycache__/
+ *.pyc
site/backend/__pycache__/app.cpython-311.pyc DELETED
Binary file (16.9 kB)
 
src/__pycache__/build_index.cpython-311.pyc DELETED
Binary file (5.5 kB)
 
src/__pycache__/data_io.cpython-311.pyc DELETED
Binary file (4.54 kB)
 
src/__pycache__/demo_cli.cpython-311.pyc DELETED
Binary file (4.81 kB)
 
src/__pycache__/evaluate.cpython-311.pyc DELETED
Binary file (6.52 kB)
 
src/__pycache__/train_biencoder.cpython-311.pyc DELETED
Binary file (3.21 kB)
 
src/__pycache__/validate.cpython-311.pyc DELETED
Binary file (4.17 kB)
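
Note: deleting these tracked .pyc files is the necessary companion to the new ignore rules above, since .gitignore alone never untracks files already in the index. A small illustrative sketch (purely a convenience check, run from the repo root) for spotting leftover bytecode:

# Illustrative sketch: list bytecode still in the working tree; each hit is
# a candidate for `git rm --cached <path>` now that .gitignore covers it.
from pathlib import Path

for pyc in Path(".").rglob("*.pyc"):
    print(pyc)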
 
src/demo_cli.py CHANGED
@@ -2,7 +2,7 @@ from pathlib import Path
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
- from src.data_io import read_jsonl
+ from data_io import read_jsonl

MODEL_PATH = Path("artifacts/models/finetuned_mpnet")
INDEX_DIR = Path("artifacts/indexes/finetuned")
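
Note: the import flip from `src.data_io` to `data_io` ties demo_cli.py to how it is launched. A minimal sketch of the mechanics, assuming data_io.py sits next to demo_cli.py in src/:

# `python src/demo_cli.py` puts the script's directory (src/) first on
# sys.path, so the flat `from data_io import read_jsonl` resolves; the old
# `from src.data_io import read_jsonl` form instead needs the repo root on
# sys.path, e.g. via `python -m src.demo_cli`.
import sys

print(sys.path[0])  # the directory Python prepended for script-style runs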
src/evaluate.py CHANGED
@@ -1,105 +1,258 @@
- import json
- from pathlib import Path
- import numpy as np
- import faiss
- from sentence_transformers import SentenceTransformer
- from src.data_io import load_pairs, read_jsonl
-
- def load_index(lang: str, alias: str):
-     base = Path("artifacts/indexes") / alias
-     idx_path = base / f"{lang}.faiss"
-     meta_path = base / f"{lang}_meta.jsonl"
-     index = faiss.read_index(str(idx_path))
-     meta = read_jsonl(str(meta_path))
-     pos_to_id = {int(x["pos"]): x["id"] for x in meta}
-     return index, meta, pos_to_id
-
- def metrics_from_ranks(ranks, ks=(1, 3, 5, 10)):
-     out = {}
-     for k in ks:
-         out[f"recall@{k}"] = float(np.mean([1.0 if r is not None and r < k else 0.0 for r in ranks]))
-     rr = []
-     dcg = []
-     for r in ranks:
-         if r is None:
-             rr.append(0.0)
-             dcg.append(0.0)
-         else:
-             rr.append(1.0 / (r + 1.0))
-             dcg.append(1.0 / np.log2(r + 2.0))
-     out["mrr@10"] = float(np.mean(rr))
-     out["ndcg@10"] = float(np.mean(dcg))
-     return out
-
- def eval_model(model_name: str, index_alias: str, test_path: str, top_k=10):
-     model = SentenceTransformer(model_name)
-
-     test = load_pairs(test_path)
-     groups = {"ru": [x for x in test if x["lang"] == "ru"], "kz": [x for x in test if x["lang"] == "kz"]}
-
-     results = {
-         "model": model_name,
-         "index_alias": index_alias,
-         "test_path": test_path,
-         "top_k": top_k,
-         "by_lang": {},
-     }
-
-     all_ranks = []
-
-     for lang, items in groups.items():
-         if not items:
-             results["by_lang"][lang] = {"count": 0}
-             continue
-
-         index, meta, pos_to_id = load_index(lang, index_alias)
-
-         queries = [x["query"] for x in items]
-         q_emb = model.encode(queries, batch_size=64, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=True).astype(np.float32)
-         scores, idxs = index.search(q_emb, top_k)
-
-         ranks = []
-         for i, x in enumerate(items):
-             target = x["positive_id"]
-             found_rank = None
-             for r in range(top_k):
-                 did = pos_to_id.get(int(idxs[i, r]))
-                 if did == target:
-                     found_rank = r
-                     break
-             ranks.append(found_rank)
-
-         all_ranks.extend(ranks)
-
-         results["by_lang"][lang] = {
-             "count": len(items),
-             **metrics_from_ranks(ranks, ks=(1, 3, 5, 10)),
-         }
-
-     results["overall"] = {
-         "count": len(all_ranks),
-         **metrics_from_ranks(all_ranks, ks=(1, 3, 5, 10)),
-     }
-     return results
-
- def main():
-     test_path = "data/legal_assistant_test.jsonl"
-
-     models = [
-         ("mpnet_base", "paraphrase-multilingual-mpnet-base-v2"),
-         ("labse", "sentence-transformers/LaBSE"),
-     ]
-
-     finetuned_dir = Path("artifacts/models/finetuned_mpnet")
-     if finetuned_dir.exists():
-         models.append(("finetuned", str(finetuned_dir)))
-
-     out_dir = Path("artifacts/reports")
-     out_dir.mkdir(parents=True, exist_ok=True)
-
-     for alias, model_name in models:
-         r = eval_model(model_name, alias, test_path, top_k=10)
-         (out_dir / f"eval_{alias}.json").write_text(json.dumps(r, ensure_ascii=False, indent=2), encoding="utf-8")
-
- if __name__ == "__main__":
-     main()
+ import json
+ from pathlib import Path
+
+ import faiss
+ import numpy as np
+ from sentence_transformers import SentenceTransformer
+
+ from src.data_io import load_pairs, read_jsonl
+
+
+ def load_index(lang: str, alias: str):
+     base = Path("artifacts/indexes") / alias
+     idx_path = base / f"{lang}.faiss"
+     meta_path = base / f"{lang}_meta.jsonl"
+     index = faiss.read_index(str(idx_path))
+     meta = read_jsonl(str(meta_path))
+     pos_to_id = {int(x["pos"]): x["id"] for x in meta}
+     return index, meta, pos_to_id
+
+
+ def _stats_from_values(values):
+     if not values:
+         return {
+             "mean": None,
+             "median": None,
+             "p10": None,
+             "p90": None,
+         }
+     arr = np.array(values, dtype=float)
+     return {
+         "mean": float(np.mean(arr)),
+         "median": float(np.median(arr)),
+         "p10": float(np.percentile(arr, 10)),
+         "p90": float(np.percentile(arr, 90)),
+     }
+
+
+ def metrics_from_ranks(ranks, ks=(1, 3, 5, 10)):
+     out = {}
+     for k in ks:
+         hits = [1.0 if r is not None and r < k else 0.0 for r in ranks]
+         hit_rate = float(np.mean(hits)) if ranks else 0.0
+         out[f"recall@{k}"] = hit_rate
+         out[f"hit@{k}"] = hit_rate
+         out[f"precision@{k}"] = float(np.mean([h / k for h in hits])) if ranks else 0.0
+
+     rr = []
+     dcg = []
+     for r in ranks:
+         if r is None:
+             rr.append(0.0)
+             dcg.append(0.0)
+         else:
+             rr.append(1.0 / (r + 1.0))
+             dcg.append(1.0 / np.log2(r + 2.0))
+     out["mrr@10"] = float(np.mean(rr)) if rr else 0.0
+     out["ndcg@10"] = float(np.mean(dcg)) if dcg else 0.0
+     out["not_found_rate"] = float(np.mean([1.0 if r is None else 0.0 for r in ranks])) if ranks else 0.0
+     return out
+
+
+ def eval_model(model_name: str, index_alias: str, test_path: str, top_k=10):
+     model = SentenceTransformer(model_name)
+
+     test = load_pairs(test_path)
+     groups = {
+         "ru": [x for x in test if x["lang"] == "ru"],
+         "kz": [x for x in test if x["lang"] == "kz"],
+     }
+
+     results = {
+         "model": model_name,
+         "index_alias": index_alias,
+         "test_path": test_path,
+         "top_k": top_k,
+         "by_lang": {},
+     }
+
+     all_ranks = []
+     all_top1_scores = []
+     all_top1_scores_tp = []
+     all_top1_scores_fp = []
+     all_margins = []
+     all_coverage_ids = set()
+     total_corpus_size = 0
+
+     for lang, items in groups.items():
+         if not items:
+             results["by_lang"][lang] = {"count": 0}
+             continue
+
+         index, meta, pos_to_id = load_index(lang, index_alias)
+         total_corpus_size += len(meta)
+
+         queries = [x["query"] for x in items]
+         q_emb = model.encode(
+             queries,
+             batch_size=64,
+             convert_to_numpy=True,
+             normalize_embeddings=True,
+             show_progress_bar=True,
+         ).astype(np.float32)
+         scores, idxs = index.search(q_emb, top_k)
+
+         ranks = []
+         top1_scores = []
+         top1_scores_tp = []
+         top1_scores_fp = []
+         margins = []
+         coverage_ids = set()
+         for i, x in enumerate(items):
+             target = x["positive_id"]
+             found_rank = None
+             top_scores = [float(s) for s in scores[i].tolist()]
+             for r in range(top_k):
+                 pos = int(idxs[i, r])
+                 did = pos_to_id.get(pos)
+                 if did is None:
+                     continue
+                 coverage_ids.add(did)
+                 if did == target:
+                     found_rank = r
+                     break
+             ranks.append(found_rank)
+
+             if top_scores:
+                 top1 = top_scores[0]
+                 top1_scores.append(top1)
+                 if found_rank == 0:
+                     top1_scores_tp.append(top1)
+                 else:
+                     top1_scores_fp.append(top1)
+
+             if len(top_scores) >= 2:
+                 margins.append(top_scores[0] - top_scores[1])
+
+         all_ranks.extend(ranks)
+         all_top1_scores.extend(top1_scores)
+         all_top1_scores_tp.extend(top1_scores_tp)
+         all_top1_scores_fp.extend(top1_scores_fp)
+         all_margins.extend(margins)
+         all_coverage_ids.update(coverage_ids)
+
+         found_ranks_1based = [r + 1 for r in ranks if r is not None]
+         rank_stats = _stats_from_values(found_ranks_1based)
+         rank_stats.update(
+             {
+                 "found_count": len(found_ranks_1based),
+                 "not_found_count": len(ranks) - len(found_ranks_1based),
+                 "not_found_rate": float(np.mean([1.0 if r is None else 0.0 for r in ranks])) if ranks else 0.0,
+             }
+         )
+
+         score_stats = _stats_from_values(top1_scores)
+         margin_stats = _stats_from_values(margins)
+         coverage = {
+             "unique_ids": len(coverage_ids),
+             "corpus_size": len(meta),
+             "coverage_ratio": float(len(coverage_ids) / len(meta)) if meta else 0.0,
+         }
+
+         results["by_lang"][lang] = {
+             "count": len(items),
+             **metrics_from_ranks(ranks, ks=(1, 3, 5, 10)),
+             "rank_stats": {
+                 "mean_rank": rank_stats["mean"],
+                 "median_rank": rank_stats["median"],
+                 "p10_rank": rank_stats["p10"],
+                 "p90_rank": rank_stats["p90"],
+                 "found_count": rank_stats["found_count"],
+                 "not_found_count": rank_stats["not_found_count"],
+                 "not_found_rate": rank_stats["not_found_rate"],
+             },
+             "score_stats": {
+                 "top1_score": score_stats,
+                 "margin_top1_top2": margin_stats,
+             },
+             "coverage": coverage,
+             "distributions": {
+                 "ranks": [r if r is not None else -1 for r in ranks],
+                 "top1_scores": top1_scores,
+                 "top1_scores_tp": top1_scores_tp,
+                 "top1_scores_fp": top1_scores_fp,
+                 "margins": margins,
+             },
+         }
+
+     overall_found_ranks_1based = [r + 1 for r in all_ranks if r is not None]
+     overall_rank_stats = _stats_from_values(overall_found_ranks_1based)
+     overall_rank_stats.update(
+         {
+             "found_count": len(overall_found_ranks_1based),
+             "not_found_count": len(all_ranks) - len(overall_found_ranks_1based),
+             "not_found_rate": float(np.mean([1.0 if r is None else 0.0 for r in all_ranks])) if all_ranks else 0.0,
+         }
+     )
+
+     overall_score_stats = _stats_from_values(all_top1_scores)
+     overall_margin_stats = _stats_from_values(all_margins)
+     overall_coverage = {
+         "unique_ids": len(all_coverage_ids),
+         "corpus_size": total_corpus_size,
+         "coverage_ratio": float(len(all_coverage_ids) / total_corpus_size) if total_corpus_size else 0.0,
+     }
+
+     results["overall"] = {
+         "count": len(all_ranks),
+         **metrics_from_ranks(all_ranks, ks=(1, 3, 5, 10)),
+         "rank_stats": {
+             "mean_rank": overall_rank_stats["mean"],
+             "median_rank": overall_rank_stats["median"],
+             "p10_rank": overall_rank_stats["p10"],
+             "p90_rank": overall_rank_stats["p90"],
+             "found_count": overall_rank_stats["found_count"],
+             "not_found_count": overall_rank_stats["not_found_count"],
+             "not_found_rate": overall_rank_stats["not_found_rate"],
+         },
+         "score_stats": {
+             "top1_score": overall_score_stats,
+             "margin_top1_top2": overall_margin_stats,
+         },
+         "coverage": overall_coverage,
+         "distributions": {
+             "ranks": [r if r is not None else -1 for r in all_ranks],
+             "top1_scores": all_top1_scores,
+             "top1_scores_tp": all_top1_scores_tp,
+             "top1_scores_fp": all_top1_scores_fp,
+             "margins": all_margins,
+         },
+     }
+     return results
+
+
+ def main():
+     test_path = "data/legal_assistant_test.jsonl"
+
+     models = [
+         ("mpnet_base", "paraphrase-multilingual-mpnet-base-v2"),
+         ("labse", "sentence-transformers/LaBSE"),
+     ]
+
+     finetuned_dir = Path("artifacts/models/finetuned_mpnet")
+     if finetuned_dir.exists():
+         models.append(("finetuned", str(finetuned_dir)))
+
+     out_dir = Path("artifacts/reports")
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     for alias, model_name in models:
+         r = eval_model(model_name, alias, test_path, top_k=10)
+         (out_dir / f"eval_{alias}.json").write_text(
+             json.dumps(r, ensure_ascii=False, indent=2),
+             encoding="utf-8",
+         )
+
+
+ if __name__ == "__main__":
+     main()
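
Note: metrics_from_ranks keeps the single-positive convention (0-based rank of the positive document, None when it falls outside top_k). A quick sanity check on toy ranks, assuming the module is importable from the repo root as src.evaluate:

# Toy check of the rewritten metrics (assumption: run from the repo root,
# matching the `from src.data_io import ...` import style above).
from src.evaluate import metrics_from_ranks

m = metrics_from_ranks([0, 2, None])  # ranks are 0-based; None = miss
assert abs(m["recall@1"] - 1 / 3) < 1e-9           # only the rank-0 query hits @1
assert abs(m["recall@3"] - 2 / 3) < 1e-9           # ranks 0 and 2 both hit @3
assert abs(m["mrr@10"] - (1 + 1 / 3) / 3) < 1e-9   # (1/1 + 1/3 + 0) / 3
assert abs(m["ndcg@10"] - 0.5) < 1e-9              # (1/log2(2) + 1/log2(4) + 0) / 3
assert abs(m["not_found_rate"] - 1 / 3) < 1e-9
print(m)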
src/plot_eval.py CHANGED
@@ -1,196 +1,522 @@
- import json
- from pathlib import Path
- import matplotlib.pyplot as plt
- import numpy as np
-
- def read_json(path):
-     return json.loads(Path(path).read_text(encoding="utf-8"))
-
- def pick_models(files):
-     items = []
-     for p in files:
-         try:
-             j = read_json(p)
-             items.append((Path(p).stem, j))
-         except Exception:
-             pass
-     return items
-
- def metric_value(obj, scope, lang, metric):
-     if scope == "overall":
-         return obj.get("overall", {}).get(metric, None)
-     if scope == "by_lang":
-         return obj.get("by_lang", {}).get(lang, {}).get(metric, None)
-     return None
-
- def save_recall_plot(models, scope, lang, out_path):
-     ks = [1, 3, 5, 10]
-     x = np.arange(len(ks))
-     width = 0.8 / max(1, len(models))
-
-     plt.figure()
-     for i, (name, obj) in enumerate(models):
-         vals = []
-         for k in ks:
-             v = metric_value(obj, scope, lang, f"recall@{k}")
-             vals.append(0.0 if v is None else float(v))
-         plt.bar(x + (i - (len(models) - 1) / 2) * width, vals, width=width, label=obj.get("model", name))
-
-     plt.xticks(x, [f"@{k}" for k in ks])
-     title = "Recall@k"
-     if scope == "overall":
-         plt.title(f"{title} (overall)")
-     else:
-         plt.title(f"{title} ({lang})")
-     plt.ylabel("score")
-     ymax = max([0.0] + [max([metric_value(o, scope, lang, f"recall@{k}") or 0.0 for k in ks]) for _, o in models])
-     plt.ylim(0, min(1.0, max(0.05, ymax * 1.2)))
-     plt.legend()
-     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
-     plt.tight_layout()
-     plt.savefig(out_path, dpi=180)
-     plt.close()
-
- def save_rank_metrics_plot(models, scope, lang, out_path):
-     metrics = ["mrr@10", "ndcg@10"]
-     x = np.arange(len(metrics))
-     width = 0.8 / max(1, len(models))
-
-     plt.figure()
-     for i, (name, obj) in enumerate(models):
-         vals = []
-         for m in metrics:
-             v = metric_value(obj, scope, lang, m)
-             vals.append(0.0 if v is None else float(v))
-         plt.bar(x + (i - (len(models) - 1) / 2) * width, vals, width=width, label=obj.get("model", name))
-
-     plt.xticks(x, metrics)
-     title = "Ranking metrics"
-     if scope == "overall":
-         plt.title(f"{title} (overall)")
-     else:
-         plt.title(f"{title} ({lang})")
-     plt.ylabel("score")
-     ymax = max([0.0] + [max([metric_value(o, scope, lang, m) or 0.0 for m in metrics]) for _, o in models])
-     plt.ylim(0, min(1.0, max(0.05, ymax * 1.2)))
-     plt.legend()
-     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
-     plt.tight_layout()
-     plt.savefig(out_path, dpi=180)
-     plt.close()
-
- def save_recall_curve_plot(models, scope, lang, out_path):
-     ks = [1, 3, 5, 10]
-     xs = np.array(ks, dtype=float)
-
-     plt.figure()
-     for name, obj in models:
-         ys = []
-         for k in ks:
-             v = metric_value(obj, scope, lang, f"recall@{k}")
-             ys.append(0.0 if v is None else float(v))
-         plt.plot(xs, ys, marker="o", label=obj.get("model", name))
-
-     plt.xticks(xs, [f"@{k}" for k in ks])
-     title = "Recall@k vs k"
-     if scope == "overall":
-         plt.title(f"{title} (overall)")
-     else:
-         plt.title(f"{title} ({lang})")
-     plt.xlabel("k")
-     plt.ylabel("recall")
-     ymax = max([0.0] + [max([metric_value(o, scope, lang, f"recall@{k}") or 0.0 for k in ks]) for _, o in models])
-     plt.ylim(0, min(1.0, max(0.05, ymax * 1.2)))
-     plt.legend()
-     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
-     plt.tight_layout()
-     plt.savefig(out_path, dpi=180)
-     plt.close()
-
- def model_label_key(obj, name):
-     s = str(obj.get("model", name)).lower()
-     if "labse" in s:
-         return "labse"
-     if "finetuned" in s or "artifacts" in s:
-         return "finetuned"
-     if "paraphrase-multilingual-mpnet-base-v2" in s:
-         return "base"
-     if "mpnet" in s:
-         return "base"
-     return name.lower()
-
- def select_model(models, key):
-     for name, obj in models:
-         if model_label_key(obj, name) == key:
-             return (name, obj)
-     return None
-
- def save_relative_improvement_plot(models, scope, lang, out_path):
-     fin = select_model(models, "finetuned")
-     base = select_model(models, "base")
-     if fin is None or base is None:
-         return
-
-     metrics = ["recall@1", "recall@3", "recall@5", "recall@10", "mrr@10", "ndcg@10"]
-     labels = ["R@1", "R@3", "R@5", "R@10", "MRR@10", "nDCG@10"]
-
-     fin_obj = fin[1]
-     base_obj = base[1]
-
-     vals = []
-     for m in metrics:
-         fv = metric_value(fin_obj, scope, lang, m)
-         bv = metric_value(base_obj, scope, lang, m)
-         fv = 0.0 if fv is None else float(fv)
-         bv = 0.0 if bv is None else float(bv)
-         if bv <= 0:
-             vals.append(np.nan)
-         else:
-             vals.append((fv - bv) / bv * 100.0)
-
-     x = np.arange(len(metrics))
-     plt.figure()
-     plt.bar(x, vals)
-     plt.xticks(x, labels)
-     title = "Relative improvement vs base (%)"
-     if scope == "overall":
-         plt.title(f"{title} (overall)")
-     else:
-         plt.title(f"{title} ({lang})")
-     plt.ylabel("%")
-     plt.axhline(0.0)
-     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
-     plt.tight_layout()
-     plt.savefig(out_path, dpi=180)
-     plt.close()
-
- def main():
-     reports_dir = Path("artifacts/reports")
-     files = sorted([str(p) for p in reports_dir.glob("eval_*.json")])
-     models = pick_models(files)
-
-     if not models:
-         raise SystemExit("No eval_*.json found in artifacts/reports")
-
-     fig_dir = reports_dir / "figures"
-     fig_dir.mkdir(parents=True, exist_ok=True)
-
-     save_recall_plot(models, "overall", None, fig_dir / "recall_overall.png")
-     save_rank_metrics_plot(models, "overall", None, fig_dir / "rank_metrics_overall.png")
-     save_recall_curve_plot(models, "overall", None, fig_dir / "recall_curve_overall.png")
-     save_relative_improvement_plot(models, "overall", None, fig_dir / "relative_improvement_overall.png")
-
-     for lang in ["ru", "kz"]:
-         save_recall_plot(models, "by_lang", lang, fig_dir / f"recall_{lang}.png")
-         save_rank_metrics_plot(models, "by_lang", lang, fig_dir / f"rank_metrics_{lang}.png")
-         save_recall_curve_plot(models, "by_lang", lang, fig_dir / f"recall_curve_{lang}.png")
-         save_relative_improvement_plot(models, "by_lang", lang, fig_dir / f"relative_improvement_{lang}.png")
-
-     summary = {
-         "loaded_reports": [Path(f).name for f in files],
-         "figures": [p.name for p in sorted(fig_dir.glob("*.png"))],
-     }
-     (reports_dir / "figures_summary.json").write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
-
- if __name__ == "__main__":
-     main()
+ import json
+ from pathlib import Path
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+
+ def read_json(path):
+     return json.loads(Path(path).read_text(encoding="utf-8"))
+
+
+ def pick_models(files):
+     items = []
+     for p in files:
+         try:
+             j = read_json(p)
+             items.append((Path(p).stem, j))
+         except Exception:
+             pass
+     return items
+
+
+ def metric_value(obj, scope, lang, metric):
+     if scope == "overall":
+         return obj.get("overall", {}).get(metric, None)
+     if scope == "by_lang":
+         return obj.get("by_lang", {}).get(lang, {}).get(metric, None)
+     return None
+
+
+ def section(obj, scope, lang):
+     if scope == "overall":
+         return obj.get("overall", {})
+     if scope == "by_lang":
+         return obj.get("by_lang", {}).get(lang, {})
+     return {}
+
+
+ def rank_stat_value(obj, scope, lang, key):
+     return section(obj, scope, lang).get("rank_stats", {}).get(key, None)
+
+
+ def score_stat_value(obj, scope, lang, group, key):
+     return section(obj, scope, lang).get("score_stats", {}).get(group, {}).get(key, None)
+
+
+ def coverage_value(obj, scope, lang, key):
+     return section(obj, scope, lang).get("coverage", {}).get(key, None)
+
+
+ def distribution_value(obj, scope, lang, key):
+     return section(obj, scope, lang).get("distributions", {}).get(key, [])
+
+
+ def save_recall_plot(models, scope, lang, out_path):
+     ks = [1, 3, 5, 10]
+     x = np.arange(len(ks))
+     width = 0.8 / max(1, len(models))
+
+     plt.figure()
+     for i, (name, obj) in enumerate(models):
+         vals = []
+         for k in ks:
+             v = metric_value(obj, scope, lang, f"recall@{k}")
+             vals.append(0.0 if v is None else float(v))
+         plt.bar(
+             x + (i - (len(models) - 1) / 2) * width,
+             vals,
+             width=width,
+             label=obj.get("model", name),
+         )
+
+     plt.xticks(x, [f"@{k}" for k in ks])
+     title = "Recall@k"
+     if scope == "overall":
+         plt.title(f"{title} (overall)")
+     else:
+         plt.title(f"{title} ({lang})")
+     plt.ylabel("score")
+     ymax = max(
+         [0.0]
+         + [
+             max(
+                 [
+                     metric_value(o, scope, lang, f"recall@{k}") or 0.0
+                     for k in ks
+                 ]
+             )
+             for _, o in models
+         ]
+     )
+     plt.ylim(0, min(1.0, max(0.05, ymax * 1.2)))
+     plt.legend()
+     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+     plt.tight_layout()
+     plt.savefig(out_path, dpi=180)
+     plt.close()
+
+
+ def save_rank_metrics_plot(models, scope, lang, out_path):
+     metrics = ["mrr@10", "ndcg@10"]
+     x = np.arange(len(metrics))
+     width = 0.8 / max(1, len(models))
+
+     plt.figure()
+     for i, (name, obj) in enumerate(models):
+         vals = []
+         for m in metrics:
+             v = metric_value(obj, scope, lang, m)
+             vals.append(0.0 if v is None else float(v))
+         plt.bar(
+             x + (i - (len(models) - 1) / 2) * width,
+             vals,
+             width=width,
+             label=obj.get("model", name),
+         )
+
+     plt.xticks(x, metrics)
+     title = "Ranking metrics"
+     if scope == "overall":
+         plt.title(f"{title} (overall)")
+     else:
+         plt.title(f"{title} ({lang})")
+     plt.ylabel("score")
+     ymax = max(
+         [0.0]
+         + [
+             max([metric_value(o, scope, lang, m) or 0.0 for m in metrics])
+             for _, o in models
+         ]
+     )
+     plt.ylim(0, min(1.0, max(0.05, ymax * 1.2)))
+     plt.legend()
+     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+     plt.tight_layout()
+     plt.savefig(out_path, dpi=180)
+     plt.close()
+
+
+ def save_precision_plot(models, scope, lang, out_path):
+     ks = [1, 3, 5, 10]
+     x = np.arange(len(ks))
+     width = 0.8 / max(1, len(models))
+
+     plt.figure()
+     any_data = False
+     for i, (name, obj) in enumerate(models):
+         vals = []
+         for k in ks:
+             v = metric_value(obj, scope, lang, f"precision@{k}")
+             if v is not None:
+                 any_data = True
+             vals.append(0.0 if v is None else float(v))
+         plt.bar(
+             x + (i - (len(models) - 1) / 2) * width,
+             vals,
+             width=width,
+             label=obj.get("model", name),
+         )
+
+     if not any_data:
+         plt.close()
+         return
+
+     plt.xticks(x, [f"@{k}" for k in ks])
+     title = "Precision@k (single-positive)"
+     if scope == "overall":
+         plt.title(f"{title} (overall)")
+     else:
+         plt.title(f"{title} ({lang})")
+     plt.ylabel("score")
+     ymax = max(
+         [0.0]
+         + [
+             max(
+                 [
+                     metric_value(o, scope, lang, f"precision@{k}") or 0.0
+                     for k in ks
+                 ]
+             )
+             for _, o in models
+         ]
+     )
+     plt.ylim(0, min(1.0, max(0.05, ymax * 1.2)))
+     plt.legend()
+     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+     plt.tight_layout()
+     plt.savefig(out_path, dpi=180)
+     plt.close()
+
+
+ def save_recall_curve_plot(models, scope, lang, out_path):
+     ks = [1, 3, 5, 10]
+     xs = np.array(ks, dtype=float)
+
+     plt.figure()
+     for name, obj in models:
+         ys = []
+         for k in ks:
+             v = metric_value(obj, scope, lang, f"recall@{k}")
+             ys.append(0.0 if v is None else float(v))
+         plt.plot(xs, ys, marker="o", label=obj.get("model", name))
+
+     plt.xticks(xs, [f"@{k}" for k in ks])
+     title = "Recall@k vs k"
+     if scope == "overall":
+         plt.title(f"{title} (overall)")
+     else:
+         plt.title(f"{title} ({lang})")
+     plt.xlabel("k")
+     plt.ylabel("recall")
+     ymax = max(
+         [0.0]
+         + [
+             max(
+                 [
+                     metric_value(o, scope, lang, f"recall@{k}") or 0.0
+                     for k in ks
+                 ]
+             )
+             for _, o in models
+         ]
+     )
+     plt.ylim(0, min(1.0, max(0.05, ymax * 1.2)))
+     plt.legend()
+     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+     plt.tight_layout()
+     plt.savefig(out_path, dpi=180)
+     plt.close()
+
+
+ def save_rank_stats_plot(models, scope, lang, out_path):
+     metrics = [("mean_rank", "Mean"), ("median_rank", "Median"), ("p90_rank", "P90")]
+     x = np.arange(len(metrics))
+     width = 0.8 / max(1, len(models))
+
+     plt.figure()
+     any_data = False
+     for i, (name, obj) in enumerate(models):
+         vals = []
+         for key, _ in metrics:
+             v = rank_stat_value(obj, scope, lang, key)
+             if v is not None:
+                 any_data = True
+             vals.append(np.nan if v is None else float(v))
+         plt.bar(
+             x + (i - (len(models) - 1) / 2) * width,
+             vals,
+             width=width,
+             label=obj.get("model", name),
+         )
+
+     if not any_data:
+         plt.close()
+         return
+
+     plt.xticks(x, [m[1] for m in metrics])
+     title = "Rank stats (1-based)"
+     if scope == "overall":
+         plt.title(f"{title} (overall)")
+     else:
+         plt.title(f"{title} ({lang})")
+     plt.ylabel("rank")
+     plt.legend()
+     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+     plt.tight_layout()
+     plt.savefig(out_path, dpi=180)
+     plt.close()
+
+
+ def save_rank_distribution_plot(models, scope, lang, out_path):
+     top_k = None
+     for _, obj in models:
+         if "top_k" in obj:
+             top_k = int(obj["top_k"])
+             break
+     if top_k is None:
+         return
+
+     x = np.arange(top_k + 1)
+     width = 0.8 / max(1, len(models))
+
+     plt.figure()
+     any_data = False
+     for i, (name, obj) in enumerate(models):
+         ranks = distribution_value(obj, scope, lang, "ranks")
+         if not ranks:
+             continue
+         any_data = True
+         buckets = [0] * (top_k + 1)
+         for r in ranks:
+             if r is None or r < 0 or r >= top_k:
+                 buckets[-1] += 1
+             else:
+                 buckets[int(r)] += 1
+         total = max(1, len(ranks))
+         vals = [b / total for b in buckets]
+         plt.bar(
+             x + (i - (len(models) - 1) / 2) * width,
+             vals,
+             width=width,
+             label=obj.get("model", name),
+         )
+
+     if not any_data:
+         plt.close()
+         return
+
+     labels = [str(i + 1) for i in range(top_k)] + ["NF"]
+     plt.xticks(x, labels)
+     title = "Rank distribution"
+     if scope == "overall":
+         plt.title(f"{title} (overall)")
+     else:
+         plt.title(f"{title} ({lang})")
+     plt.ylabel("share of queries")
+     plt.legend()
+     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+     plt.tight_layout()
+     plt.savefig(out_path, dpi=180)
+     plt.close()
+
+
+ def save_margin_boxplot(models, scope, lang, out_path):
+     data = []
+     labels = []
+     for name, obj in models:
+         margins = distribution_value(obj, scope, lang, "margins")
+         if margins:
+             data.append(margins)
+             labels.append(obj.get("model", name))
+
+     if not data:
+         return
+
+     plt.figure()
+     plt.boxplot(data, labels=labels, showfliers=False)
+     title = "Score margin (top1 - top2)"
+     if scope == "overall":
+         plt.title(f"{title} (overall)")
+     else:
+         plt.title(f"{title} ({lang})")
+     plt.ylabel("margin")
+     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+     plt.tight_layout()
+     plt.savefig(out_path, dpi=180)
+     plt.close()
+
+
+ def save_coverage_plot(models, scope, lang, out_path):
+     vals = []
+     labels = []
+     for name, obj in models:
+         v = coverage_value(obj, scope, lang, "coverage_ratio")
+         if v is not None:
+             vals.append(float(v))
+             labels.append(obj.get("model", name))
+
+     if not vals:
+         return
+
+     x = np.arange(len(vals))
+     plt.figure()
+     plt.bar(x, vals)
+     plt.xticks(x, labels, rotation=15, ha="right")
+     title = "Coverage ratio (unique docs / corpus)"
+     if scope == "overall":
+         plt.title(f"{title} (overall)")
+     else:
+         plt.title(f"{title} ({lang})")
+     plt.ylabel("ratio")
+     plt.ylim(0, 1.0)
+     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+     plt.tight_layout()
+     plt.savefig(out_path, dpi=180)
+     plt.close()
+
+
+ def save_top1_score_hist(models, scope, lang, out_dir):
+     for name, obj in models:
+         tp = distribution_value(obj, scope, lang, "top1_scores_tp")
+         fp = distribution_value(obj, scope, lang, "top1_scores_fp")
+         if not tp and not fp:
+             continue
+         plt.figure()
+         if tp:
+             plt.hist(tp, bins=20, alpha=0.6, label="top-1 is positive")
+         if fp:
+             plt.hist(fp, bins=20, alpha=0.6, label="top-1 is not positive")
+         title = "Top-1 score distribution"
+         label = obj.get("model", name)
+         if scope == "overall":
+             plt.title(f"{title} ({label}, overall)")
+         else:
+             plt.title(f"{title} ({label}, {lang})")
+         plt.xlabel("similarity score")
+         plt.ylabel("count")
+         plt.legend()
+         Path(out_dir).mkdir(parents=True, exist_ok=True)
+         out_path = (
+             Path(out_dir)
+             / f"top1_score_tp_fp_{model_label_key(obj, name)}_{scope if scope else 'overall'}{'' if lang is None else '_' + lang}.png"
+         )
+         plt.tight_layout()
+         plt.savefig(out_path, dpi=180)
+         plt.close()
+
+
+ def model_label_key(obj, name):
+     s = str(obj.get("model", name)).lower()
+     if "labse" in s:
+         return "labse"
+     if "finetuned" in s or "artifacts" in s:
+         return "finetuned"
+     if "paraphrase-multilingual-mpnet-base-v2" in s:
+         return "base"
+     if "mpnet" in s:
+         return "base"
+     return name.lower()
+
+
+ def select_model(models, key):
+     for name, obj in models:
+         if model_label_key(obj, name) == key:
+             return (name, obj)
+     return None
+
+
+ def save_relative_improvement_plot(models, scope, lang, out_path):
+     fin = select_model(models, "finetuned")
+     base = select_model(models, "base")
+     if fin is None or base is None:
+         return
+
+     metrics = ["recall@1", "recall@3", "recall@5", "recall@10", "mrr@10", "ndcg@10"]
+     labels = ["R@1", "R@3", "R@5", "R@10", "MRR@10", "nDCG@10"]
+
+     fin_obj = fin[1]
+     base_obj = base[1]
+
+     vals = []
+     for m in metrics:
+         fv = metric_value(fin_obj, scope, lang, m)
+         bv = metric_value(base_obj, scope, lang, m)
+         fv = 0.0 if fv is None else float(fv)
+         bv = 0.0 if bv is None else float(bv)
+         if bv <= 0:
+             vals.append(np.nan)
+         else:
+             vals.append((fv - bv) / bv * 100.0)
+
+     x = np.arange(len(metrics))
+     plt.figure()
+     plt.bar(x, vals)
+     plt.xticks(x, labels)
+     title = "Relative improvement vs base (%)"
+     if scope == "overall":
+         plt.title(f"{title} (overall)")
+     else:
+         plt.title(f"{title} ({lang})")
+     plt.ylabel("%")
+     plt.axhline(0.0)
+     Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+     plt.tight_layout()
+     plt.savefig(out_path, dpi=180)
+     plt.close()
+
+
+ def main():
+     reports_dir = Path("artifacts/reports")
+     files = sorted([str(p) for p in reports_dir.glob("eval_*.json")])
+     models = pick_models(files)
+
+     if not models:
+         raise SystemExit("No eval_*.json found in artifacts/reports")
+
+     fig_dir = reports_dir / "figures"
+     fig_dir.mkdir(parents=True, exist_ok=True)
+
+     save_recall_plot(models, "overall", None, fig_dir / "recall_overall.png")
+     save_rank_metrics_plot(models, "overall", None, fig_dir / "rank_metrics_overall.png")
+     save_recall_curve_plot(models, "overall", None, fig_dir / "recall_curve_overall.png")
+     save_relative_improvement_plot(models, "overall", None, fig_dir / "relative_improvement_overall.png")
+     save_precision_plot(models, "overall", None, fig_dir / "precision_overall.png")
+     save_rank_stats_plot(models, "overall", None, fig_dir / "rank_stats_overall.png")
+     save_rank_distribution_plot(
+         models, "overall", None, fig_dir / "rank_distribution_overall.png"
+     )
+     save_margin_boxplot(models, "overall", None, fig_dir / "score_margin_overall.png")
+     save_coverage_plot(models, "overall", None, fig_dir / "coverage_overall.png")
+     save_top1_score_hist(models, "overall", None, fig_dir)
+
+     for lang in ["ru", "kz"]:
+         save_recall_plot(models, "by_lang", lang, fig_dir / f"recall_{lang}.png")
+         save_rank_metrics_plot(
+             models, "by_lang", lang, fig_dir / f"rank_metrics_{lang}.png"
+         )
+         save_recall_curve_plot(
+             models, "by_lang", lang, fig_dir / f"recall_curve_{lang}.png"
+         )
+         save_relative_improvement_plot(
+             models, "by_lang", lang, fig_dir / f"relative_improvement_{lang}.png"
+         )
+         save_precision_plot(models, "by_lang", lang, fig_dir / f"precision_{lang}.png")
+         save_rank_stats_plot(models, "by_lang", lang, fig_dir / f"rank_stats_{lang}.png")
+         save_rank_distribution_plot(
+             models, "by_lang", lang, fig_dir / f"rank_distribution_{lang}.png"
+         )
+         save_coverage_plot(models, "by_lang", lang, fig_dir / f"coverage_{lang}.png")
+
+     summary = {
+         "loaded_reports": [Path(f).name for f in files],
+         "figures": [p.name for p in sorted(fig_dir.glob("*.png"))],
+     }
+     (reports_dir / "figures_summary.json").write_text(
+         json.dumps(summary, ensure_ascii=False, indent=2),
+         encoding="utf-8",
+     )
+
+
+ if __name__ == "__main__":
+     main()
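
Note: together the two scripts form a small evaluate-then-plot pipeline. A minimal driver sketch (assumptions: repo root as working directory, src/ importable as a package, artifacts/ and data/legal_assistant_test.jsonl in place):

# Run evaluation, then render every figure from the fresh reports.
import runpy

runpy.run_module("src.evaluate", run_name="__main__")   # writes artifacts/reports/eval_*.json
runpy.run_module("src.plot_eval", run_name="__main__")  # writes artifacts/reports/figures/*.png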