lsdf committed on
Commit
f8eb22b
·
1 Parent(s): d2ba52e

Implement staged multi-objective optimization flow after BERT.

Browse files

Add explicit stage pipeline (BERT -> BM25 -> Semantic -> N-gram -> Title), stage progress/plateau transitions, and expose candidate phrase strategy in debug table for clearer diagnostics.

Made-with: Cursor

Files changed (2) hide show
  1. optimizer.py +127 -8
  2. templates/index.html +4 -2
optimizer.py CHANGED
@@ -22,8 +22,10 @@ STOP_WORDS = {
22
 
23
  BERT_TARGET_THRESHOLD = 0.7
24
  BERT_GOAL_DELTA_MIN = 0.005
 
25
  SEMANTIC_GAP_TOLERANCE_PCT = 0.15
26
  SEMANTIC_GAP_MIN_ABS = 3.0
 
27
 
28
 
29
  def _tokenize(text: str) -> List[str]:
@@ -338,18 +340,25 @@ def _compute_metrics(analysis: Dict[str, Any], semantic: Dict[str, Any], keyword
338
  }
339
 
340
 
341
- def _choose_optimization_goal(analysis: Dict[str, Any], semantic: Dict[str, Any], keywords: List[str], language: str) -> Dict[str, Any]:
 
 
 
 
 
 
 
342
  bert_details = analysis.get("bert_analysis", {}).get("detailed", []) or []
343
  low_bert = [x for x in bert_details if float(x.get("my_max_score", 0)) < BERT_TARGET_THRESHOLD]
344
  if low_bert:
345
  worst = sorted(low_bert, key=lambda x: float(x.get("my_max_score", 0)))[0]
346
  focus_terms = _filter_stopwords(_tokenize(worst.get("phrase", "")), language)[:4]
347
- return {"type": "bert", "label": str(worst.get("phrase", "")), "focus_terms": focus_terms, "avoid_terms": []}
348
 
349
  bm25_remove = [x for x in (analysis.get("bm25_recommendations") or []) if x.get("action") == "remove"]
350
  if len(bm25_remove) >= 4:
351
  spam_terms = [str(x.get("word", "")) for x in sorted(bm25_remove, key=lambda r: int(r.get("count", 0)), reverse=True)[:4]]
352
- return {"type": "bm25", "label": "reduce spam", "focus_terms": [], "avoid_terms": spam_terms}
353
 
354
  # Semantic keyword gaps
355
  lang_stop = STOP_WORDS.get(language, STOP_WORDS["en"])
@@ -373,7 +382,7 @@ def _choose_optimization_goal(analysis: Dict[str, Any], semantic: Dict[str, Any]
373
  candidate_rows.append((term, gap))
374
  if candidate_rows:
375
  top_term = sorted(candidate_rows, key=lambda x: x[1], reverse=True)[0][0]
376
- return {"type": "semantic", "label": top_term, "focus_terms": [top_term], "avoid_terms": []}
377
 
378
  # Fallback: ngram add signal
379
  for bucket_name in ("unigrams", "bigrams"):
@@ -382,7 +391,28 @@ def _choose_optimization_goal(analysis: Dict[str, Any], semantic: Dict[str, Any]
382
  target = float(item.get("target_count", 0))
383
  comp_avg = float(item.get("competitor_avg", 0))
384
  if (target == 0 and comp_avg > 0) or (target > 0 and comp_avg >= target * 2):
385
- return {"type": "ngram", "label": str(item.get("ngram", "")), "focus_terms": _tokenize(str(item.get("ngram", "")))[:3], "avoid_terms": []}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
  return {"type": "none", "label": "no-op", "focus_terms": [], "avoid_terms": []}
388
 
@@ -906,6 +936,48 @@ def _safe_delta(prev_metrics: Dict[str, Any], next_metrics: Dict[str, Any], key:
906
  return 0.0
907
 
908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
909
  def _candidate_utility(
910
  *,
911
  prev_metrics: Dict[str, Any],
@@ -1036,13 +1108,38 @@ def optimize_text(request_data: Dict[str, Any]) -> Dict[str, Any]:
1036
  goal_attempt_cursor: Dict[str, int] = {}
1037
  attempted_spans = set()
1038
  queued_candidates: List[Dict[str, Any]] = []
 
 
1039
 
1040
  for step in range(max_iterations):
1041
- goal = _choose_optimization_goal(current_analysis, current_semantic, keywords, language)
1042
- if goal["type"] == "none":
1043
- logs.append({"step": step + 1, "status": "stopped", "reason": "No optimization goals left."})
 
 
1044
  break
1045
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1046
  sentences = _split_sentences(current_text)
1047
  if not sentences:
1048
  logs.append({"step": step + 1, "status": "stopped", "reason": "No sentences available for editing."})
@@ -1347,6 +1444,10 @@ def optimize_text(request_data: Dict[str, Any]) -> Dict[str, Any]:
1347
  current_analysis = best_local["analysis"]
1348
  current_semantic = best_local["semantic"]
1349
  current_metrics = best_local["metrics"]
 
 
 
 
1350
  applied_changes += 1
1351
  queued_candidates = []
1352
 
@@ -1354,6 +1455,7 @@ def optimize_text(request_data: Dict[str, Any]) -> Dict[str, Any]:
1354
  {
1355
  "step": step + 1,
1356
  "status": "applied_local_progress",
 
1357
  "goal": goal,
1358
  "cascade_level": cascade_level,
1359
  "operation": best_local.get("operation"),
@@ -1482,6 +1584,10 @@ def optimize_text(request_data: Dict[str, Any]) -> Dict[str, Any]:
1482
  current_analysis = best_batch["batch_analysis"]
1483
  current_semantic = best_batch["batch_semantic"]
1484
  current_metrics = best_batch["batch_metrics"]
 
 
 
 
1485
  applied_changes += 1
1486
  batch_applied = True
1487
  batch_info = {
@@ -1509,6 +1615,7 @@ def optimize_text(request_data: Dict[str, Any]) -> Dict[str, Any]:
1509
  {
1510
  "step": step + 1,
1511
  "status": "applied_batch",
 
1512
  "goal": goal,
1513
  "cascade_level": cascade_level,
1514
  "operation": "batch",
@@ -1553,6 +1660,7 @@ def optimize_text(request_data: Dict[str, Any]) -> Dict[str, Any]:
1553
  {
1554
  "step": step + 1,
1555
  "status": "rejected",
 
1556
  "goal": goal,
1557
  "cascade_level": cascade_level,
1558
  "operation": primary_span.get("operation"),
@@ -1597,6 +1705,12 @@ def optimize_text(request_data: Dict[str, Any]) -> Dict[str, Any]:
1597
  ],
1598
  }
1599
  )
 
 
 
 
 
 
1600
  consecutive_failures += 1
1601
  if consecutive_failures >= 2 and cascade_level < 4:
1602
  cascade_level += 1
@@ -1621,6 +1735,10 @@ def optimize_text(request_data: Dict[str, Any]) -> Dict[str, Any]:
1621
  current_analysis = best["analysis"]
1622
  current_semantic = best["semantic"]
1623
  current_metrics = best["metrics"]
 
 
 
 
1624
  applied_changes += 1
1625
  queued_candidates = []
1626
 
@@ -1628,6 +1746,7 @@ def optimize_text(request_data: Dict[str, Any]) -> Dict[str, Any]:
1628
  {
1629
  "step": step + 1,
1630
  "status": "applied",
 
1631
  "goal": goal,
1632
  "cascade_level": cascade_level,
1633
  "operation": best.get("operation"),
 
22
 
23
  BERT_TARGET_THRESHOLD = 0.7
24
  BERT_GOAL_DELTA_MIN = 0.005
25
+ TITLE_TARGET_THRESHOLD = 0.65
26
  SEMANTIC_GAP_TOLERANCE_PCT = 0.15
27
  SEMANTIC_GAP_MIN_ABS = 3.0
28
+ STAGE_ORDER = ["bert", "bm25", "semantic", "ngram", "title"]
29
 
30
 
31
  def _tokenize(text: str) -> List[str]:
 
340
  }
341
 
342
 
343
+ def _choose_optimization_goal(
344
+ analysis: Dict[str, Any],
345
+ semantic: Dict[str, Any],
346
+ keywords: List[str],
347
+ language: str,
348
+ stage: str = "bert",
349
+ ) -> Dict[str, Any]:
350
+ candidates: Dict[str, Dict[str, Any]] = {}
351
  bert_details = analysis.get("bert_analysis", {}).get("detailed", []) or []
352
  low_bert = [x for x in bert_details if float(x.get("my_max_score", 0)) < BERT_TARGET_THRESHOLD]
353
  if low_bert:
354
  worst = sorted(low_bert, key=lambda x: float(x.get("my_max_score", 0)))[0]
355
  focus_terms = _filter_stopwords(_tokenize(worst.get("phrase", "")), language)[:4]
356
+ candidates["bert"] = {"type": "bert", "label": str(worst.get("phrase", "")), "focus_terms": focus_terms, "avoid_terms": []}
357
 
358
  bm25_remove = [x for x in (analysis.get("bm25_recommendations") or []) if x.get("action") == "remove"]
359
  if len(bm25_remove) >= 4:
360
  spam_terms = [str(x.get("word", "")) for x in sorted(bm25_remove, key=lambda r: int(r.get("count", 0)), reverse=True)[:4]]
361
+ candidates["bm25"] = {"type": "bm25", "label": "reduce spam", "focus_terms": [], "avoid_terms": spam_terms}
362
 
363
  # Semantic keyword gaps
364
  lang_stop = STOP_WORDS.get(language, STOP_WORDS["en"])
 
382
  candidate_rows.append((term, gap))
383
  if candidate_rows:
384
  top_term = sorted(candidate_rows, key=lambda x: x[1], reverse=True)[0][0]
385
+ candidates["semantic"] = {"type": "semantic", "label": top_term, "focus_terms": [top_term], "avoid_terms": []}
386
 
387
  # Fallback: ngram add signal
388
  for bucket_name in ("unigrams", "bigrams"):
 
391
  target = float(item.get("target_count", 0))
392
  comp_avg = float(item.get("competitor_avg", 0))
393
  if (target == 0 and comp_avg > 0) or (target > 0 and comp_avg >= target * 2):
394
+ candidates["ngram"] = {
395
+ "type": "ngram",
396
+ "label": str(item.get("ngram", "")),
397
+ "focus_terms": _tokenize(str(item.get("ngram", "")))[:3],
398
+ "avoid_terms": [],
399
+ }
400
+ break
401
+ if "ngram" in candidates:
402
+ break
403
+
404
+ title_bert = analysis.get("title_analysis", {}).get("bert", {}) or {}
405
+ title_target_score = title_bert.get("target_score")
406
+ if title_target_score is not None and float(title_target_score) < TITLE_TARGET_THRESHOLD:
407
+ candidates["title"] = {
408
+ "type": "title",
409
+ "label": "title alignment",
410
+ "focus_terms": _filter_stopwords(_tokenize(" ".join(keywords[:2])), language)[:4],
411
+ "avoid_terms": [],
412
+ }
413
+
414
+ if stage in candidates:
415
+ return candidates[stage]
416
 
417
  return {"type": "none", "label": "no-op", "focus_terms": [], "avoid_terms": []}
418
 
 
936
  return 0.0
937
 
938
 
939
+ def _stage_primary_progress(stage: str, prev_metrics: Dict[str, Any], next_metrics: Dict[str, Any]) -> bool:
940
+ if stage == "bert":
941
+ prev_low = int(prev_metrics.get("bert_low_count", 0))
942
+ next_low = int(next_metrics.get("bert_low_count", 0))
943
+ if next_low < prev_low:
944
+ return True
945
+ prev_max = max([0.0] + [float(v) for v in (prev_metrics.get("bert_phrase_scores") or {}).values()])
946
+ next_max = max([0.0] + [float(v) for v in (next_metrics.get("bert_phrase_scores") or {}).values()])
947
+ return (next_max - prev_max) >= BERT_GOAL_DELTA_MIN
948
+ if stage == "bm25":
949
+ return int(next_metrics.get("bm25_remove_count", 0)) < int(prev_metrics.get("bm25_remove_count", 0))
950
+ if stage == "semantic":
951
+ return (
952
+ int(next_metrics.get("semantic_gap_count", 0)) < int(prev_metrics.get("semantic_gap_count", 0))
953
+ or float(next_metrics.get("semantic_gap_sum", 0.0)) < float(prev_metrics.get("semantic_gap_sum", 0.0))
954
+ )
955
+ if stage == "ngram":
956
+ return int(next_metrics.get("ngram_signal_count", 0)) < int(prev_metrics.get("ngram_signal_count", 0))
957
+ if stage == "title":
958
+ pv = prev_metrics.get("title_bert_score")
959
+ nv = next_metrics.get("title_bert_score")
960
+ if pv is None or nv is None:
961
+ return False
962
+ return float(nv) > float(pv)
963
+ return False
964
+
965
+
966
+ def _is_stage_complete(stage: str, metrics: Dict[str, Any]) -> bool:
967
+ if stage == "bert":
968
+ return int(metrics.get("bert_low_count", 0)) == 0
969
+ if stage == "bm25":
970
+ return int(metrics.get("bm25_remove_count", 0)) <= 3
971
+ if stage == "semantic":
972
+ return int(metrics.get("semantic_gap_count", 0)) <= 0
973
+ if stage == "ngram":
974
+ return int(metrics.get("ngram_signal_count", 0)) <= 0
975
+ if stage == "title":
976
+ score = metrics.get("title_bert_score")
977
+ return (score is None) or (float(score) >= TITLE_TARGET_THRESHOLD)
978
+ return True
979
+
980
+
981
  def _candidate_utility(
982
  *,
983
  prev_metrics: Dict[str, Any],
 
1108
  goal_attempt_cursor: Dict[str, int] = {}
1109
  attempted_spans = set()
1110
  queued_candidates: List[Dict[str, Any]] = []
1111
+ stage_idx = 0
1112
+ stage_no_progress_steps = 0
1113
 
1114
  for step in range(max_iterations):
1115
+ while stage_idx < len(STAGE_ORDER) and _is_stage_complete(STAGE_ORDER[stage_idx], current_metrics):
1116
+ stage_idx += 1
1117
+ stage_no_progress_steps = 0
1118
+ if stage_idx >= len(STAGE_ORDER):
1119
+ logs.append({"step": step + 1, "status": "stopped", "reason": "All optimization stages completed."})
1120
  break
1121
 
1122
+ active_stage = STAGE_ORDER[stage_idx]
1123
+ goal = _choose_optimization_goal(
1124
+ current_analysis,
1125
+ current_semantic,
1126
+ keywords,
1127
+ language,
1128
+ stage=active_stage,
1129
+ )
1130
+ if goal["type"] == "none":
1131
+ stage_idx += 1
1132
+ stage_no_progress_steps = 0
1133
+ logs.append(
1134
+ {
1135
+ "step": step + 1,
1136
+ "status": "stage_skipped",
1137
+ "stage": active_stage,
1138
+ "reason": f"No actionable goals for stage '{active_stage}', moving to next stage.",
1139
+ }
1140
+ )
1141
+ continue
1142
+
1143
  sentences = _split_sentences(current_text)
1144
  if not sentences:
1145
  logs.append({"step": step + 1, "status": "stopped", "reason": "No sentences available for editing."})
 
1444
  current_analysis = best_local["analysis"]
1445
  current_semantic = best_local["semantic"]
1446
  current_metrics = best_local["metrics"]
1447
+ if _stage_primary_progress(active_stage, prev_metrics, current_metrics):
1448
+ stage_no_progress_steps = 0
1449
+ else:
1450
+ stage_no_progress_steps += 1
1451
  applied_changes += 1
1452
  queued_candidates = []
1453
 
 
1455
  {
1456
  "step": step + 1,
1457
  "status": "applied_local_progress",
1458
+ "stage": active_stage,
1459
  "goal": goal,
1460
  "cascade_level": cascade_level,
1461
  "operation": best_local.get("operation"),
 
1584
  current_analysis = best_batch["batch_analysis"]
1585
  current_semantic = best_batch["batch_semantic"]
1586
  current_metrics = best_batch["batch_metrics"]
1587
+ if _stage_primary_progress(active_stage, prev_metrics, current_metrics):
1588
+ stage_no_progress_steps = 0
1589
+ else:
1590
+ stage_no_progress_steps += 1
1591
  applied_changes += 1
1592
  batch_applied = True
1593
  batch_info = {
 
1615
  {
1616
  "step": step + 1,
1617
  "status": "applied_batch",
1618
+ "stage": active_stage,
1619
  "goal": goal,
1620
  "cascade_level": cascade_level,
1621
  "operation": "batch",
 
1660
  {
1661
  "step": step + 1,
1662
  "status": "rejected",
1663
+ "stage": active_stage,
1664
  "goal": goal,
1665
  "cascade_level": cascade_level,
1666
  "operation": primary_span.get("operation"),
 
1705
  ],
1706
  }
1707
  )
1708
+ stage_no_progress_steps += 1
1709
+ if stage_no_progress_steps >= 3 and stage_idx < len(STAGE_ORDER) - 1:
1710
+ stage_idx += 1
1711
+ stage_no_progress_steps = 0
1712
+ logs[-1]["advanced_to_stage"] = STAGE_ORDER[stage_idx]
1713
+ logs[-1]["reason"] = f"{logs[-1].get('reason', '-') } Stage plateau: no primary progress for 3 steps."
1714
  consecutive_failures += 1
1715
  if consecutive_failures >= 2 and cascade_level < 4:
1716
  cascade_level += 1
 
1735
  current_analysis = best["analysis"]
1736
  current_semantic = best["semantic"]
1737
  current_metrics = best["metrics"]
1738
+ if _stage_primary_progress(active_stage, prev_metrics, current_metrics):
1739
+ stage_no_progress_steps = 0
1740
+ else:
1741
+ stage_no_progress_steps += 1
1742
  applied_changes += 1
1743
  queued_candidates = []
1744
 
 
1746
  {
1747
  "step": step + 1,
1748
  "status": "applied",
1749
+ "stage": active_stage,
1750
  "goal": goal,
1751
  "cascade_level": cascade_level,
1752
  "operation": best.get("operation"),
templates/index.html CHANGED
@@ -871,6 +871,7 @@
871
  const candidateRows = candidates.map(c => {
872
  const reasons = Array.isArray(c.invalid_reasons) ? c.invalid_reasons.join(', ') : '';
873
  const sentAfter = c.sentence_after ? safeHtml(c.sentence_after) : '-';
 
874
  const relBefore = (c.chunk_relevance_before ?? '-');
875
  const relAfter = (c.chunk_relevance_after ?? '-');
876
  const termDiff = c.term_diff ? safeHtml(JSON.stringify(c.term_diff)) : '-';
@@ -880,6 +881,7 @@
880
  return `
881
  <tr>
882
  <td>${c.candidate_index ?? '-'}</td>
 
883
  <td>${c.valid ? 'yes' : 'no'}</td>
884
  <td>${c.goal_improved ? 'yes' : 'no'}</td>
885
  <td>${c.bert_phrase_delta ?? '-'}</td>
@@ -926,10 +928,10 @@
926
  <table class="table table-sm table-bordered mb-0">
927
  <thead class="table-light">
928
  <tr>
929
- <th>#cand</th><th>valid</th><th>goal+</th><th>bert Δ</th><th>local+</th><th>chunk Δ</th><th>rel b→a</th><th>Δ</th><th>score</th><th>reject reason/error</th><th>кандидат правки</th>
930
  </tr>
931
  </thead>
932
- <tbody>${candidateRows || '<tr><td colspan="11" class="text-center text-muted">Нет кандидатов</td></tr>'}</tbody>
933
  </table>
934
  </div>
935
  </div>
 
871
  const candidateRows = candidates.map(c => {
872
  const reasons = Array.isArray(c.invalid_reasons) ? c.invalid_reasons.join(', ') : '';
873
  const sentAfter = c.sentence_after ? safeHtml(c.sentence_after) : '-';
874
+ const strategy = c.phrase_strategy_used || (c.llm_prompt_debug && c.llm_prompt_debug.phrase_strategy_mode) || '-';
875
  const relBefore = (c.chunk_relevance_before ?? '-');
876
  const relAfter = (c.chunk_relevance_after ?? '-');
877
  const termDiff = c.term_diff ? safeHtml(JSON.stringify(c.term_diff)) : '-';
 
881
  return `
882
  <tr>
883
  <td>${c.candidate_index ?? '-'}</td>
884
+ <td>${safeHtml(strategy)}</td>
885
  <td>${c.valid ? 'yes' : 'no'}</td>
886
  <td>${c.goal_improved ? 'yes' : 'no'}</td>
887
  <td>${c.bert_phrase_delta ?? '-'}</td>
 
928
  <table class="table table-sm table-bordered mb-0">
929
  <thead class="table-light">
930
  <tr>
931
+ <th>#cand</th><th>strategy</th><th>valid</th><th>goal+</th><th>bert Δ</th><th>local+</th><th>chunk Δ</th><th>rel b→a</th><th>Δ</th><th>score</th><th>reject reason/error</th><th>кандидат правки</th>
932
  </tr>
933
  </thead>
934
+ <tbody>${candidateRows || '<tr><td colspan="12" class="text-center text-muted">Нет кандидатов</td></tr>'}</tbody>
935
  </table>
936
  </div>
937
  </div>