| """Mòdul per a l'agent d'"introspection".
|
|
|
| Implementa:
|
|
|
| - Un procés d'entrenament que aprèn de les correccions HITL comparant
|
| `une_ad` automàtic (MoE/Salamandra) amb `ok_une_ad`, la versió corregida HITL.
|
| - Un pas d'introspecció que aplica aquestes regles a un nou SRT utilitzant
|
| GPT-4o-mini.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import json
|
| import logging
|
| import os
|
| import sqlite3
|
| from pathlib import Path
|
| from typing import Iterable, List, Optional, Tuple
|
|
|
| from langchain_openai import ChatOpenAI
|
| from langchain_core.messages import HumanMessage, SystemMessage
|
|
|
|
|
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# --- Filesystem layout -------------------------------------------------------

# Directory that contains this module.
BASE_DIR = Path(__file__).resolve().parent

# Two levels above BASE_DIR; the ``demo`` directory is resolved under it.
REPO_ROOT = BASE_DIR.parents[1]
DEMO_DIR = REPO_ROOT / "demo"
DEMO_TEMP_DIR = DEMO_DIR / "temp"

# Working directory for the files produced by training. It is created at
# import time so that later code can open files inside it.
REFINEMENT_TEMP_DIR = BASE_DIR / "temp"
REFINEMENT_TEMP_DIR.mkdir(parents=True, exist_ok=True)

# Text files appended to by training and read back by the refinement step.
FEW_SHOT_PATH = REFINEMENT_TEMP_DIR / "few_shot_examples.txt"
RULES_PATH = REFINEMENT_TEMP_DIR / "rules.txt"

# SQLite database queried for the automatic and HITL-corrected descriptions.
AUDIODESCRIPTIONS_DB_PATH = DEMO_TEMP_DIR / "audiodescriptions.db"
|
|
|
|
|
def _get_llm() -> Optional[ChatOpenAI]:
    """Build the GPT-4o-mini chat client used by the introspection steps.

    The API key is read from the ``OPENAI_API_KEY`` environment variable.

    Returns:
        A ``ChatOpenAI`` instance created with ``model="gpt-4o-mini"`` and
        ``temperature=0.0``, or ``None`` when the key is missing or blank or
        when the client cannot be constructed. Failures are logged, never
        raised.
    """
    # Surrounding whitespace (for example a trailing newline copied from a
    # secrets file) is never part of a usable key, so it is dropped and a
    # blank value is treated as "not configured".
    api_key = (os.environ.get("OPENAI_API_KEY") or "").strip()
    if not api_key:
        logger.warning("OPENAI_API_KEY no está configurada; se omite la introspection.")
        return None

    try:
        return ChatOpenAI(model="gpt-4o-mini", temperature=0.0, api_key=api_key)
    except Exception as exc:
        # Best-effort by design: the callers in this module fall back to a
        # no-op when no client is available, so the error is only reported.
        logger.error("No se pudo inicializar ChatOpenAI para introspection: %s", exc)
        return None
|
|
|
|
|
|
|
|
|
def _iter_une_vs_hitl_pairs() -> Iterable[Tuple[str, str, str]]:
    """Yield ``(sha1sum, une_ad_auto, une_ad_hitl)`` for every HITL correction.

    Rows are read from the ``audiodescriptions`` table of
    ``AUDIODESCRIPTIONS_DB_PATH``, restricted to the automatic versions
    (``version`` equal to ``'MoE'`` or ``'Salamandra'``):

    - ``une_ad_auto`` is the row's ``une_ad`` column, stripped.
    - ``une_ad_hitl`` is the row's ``ok_une_ad`` column, stripped.

    A row is skipped when either text is empty/NULL or when both texts are
    identical, i.e. when there is no correction to learn from.

    Nothing is yielded (and a warning is logged) when the database file does
    not exist or when the query fails with ``sqlite3.OperationalError``.
    """
    if not AUDIODESCRIPTIONS_DB_PATH.exists():
        logger.warning("audiodescriptions.db no encontrado en %s", AUDIODESCRIPTIONS_DB_PATH)
        return

    # Read everything up front and release the connection *before* yielding:
    # the consumer may take a long time per pair or stop early, and the
    # database handle should not stay open for as long as the generator lives.
    conn = sqlite3.connect(str(AUDIODESCRIPTIONS_DB_PATH))
    try:
        try:
            rows = conn.execute(
                """
                SELECT sha1sum, une_ad, ok_une_ad
                FROM audiodescriptions
                WHERE version IN ('MoE', 'Salamandra')
                """
            ).fetchall()
        except sqlite3.OperationalError:
            logger.warning("Tabla audiodescriptions no disponible en %s", AUDIODESCRIPTIONS_DB_PATH)
            return
    finally:
        conn.close()

    for sha1sum, raw_auto, raw_hitl in rows:
        # NULL columns are treated like empty text.
        une_auto = (raw_auto or "").strip()
        une_hitl = (raw_hitl or "").strip()

        if not une_auto or not une_hitl:
            continue
        if une_auto == une_hitl:
            # Identical texts: the human changed nothing.
            continue

        yield sha1sum, une_auto, une_hitl
|
|
|
|
|
def _strip_markdown_fences(content: str) -> str:
    """Remove a surrounding Markdown code fence (```...```) from ``content``.

    Text that does not start with a fence is returned stripped and otherwise
    untouched. For fenced text:

    - Multi-line: the whole opening line (fence plus optional language tag)
      is dropped, as are trailing lines that start with a fence and a
      closing fence glued to the end of the last content line.
    - Single line (opening and closing fence on the same line): the two
      fences are removed and the text between them is returned. A language
      tag cannot be told apart from content in this form, so it is kept.
    - A lone opening fence with no closing fence yields ``""``.

    Args:
        content: Raw text, typically an LLM answer.

    Returns:
        The unfenced text, stripped of surrounding whitespace.
    """
    fence = "```"
    text = content.strip()
    if not text.startswith(fence):
        return text

    lines = text.splitlines()
    if len(lines) == 1:
        # Everything sits on one line, so there is no separate opening line
        # to discard; only unwrap when a closing fence is really there.
        if len(text) >= 2 * len(fence) and text.endswith(fence):
            return text[len(fence):-len(fence)].strip()
        return ""

    body = lines[1:]
    while body and body[-1].strip().startswith(fence):
        body.pop()
    text = "\n".join(body).strip()

    # The closing fence may share a line with the last piece of content.
    if text.endswith(fence):
        text = text[:-len(fence)].rstrip()
    return text
|
|
|
|
|
def _llm_field_to_text(value: object, indent: Optional[int] = None) -> str:
    """Render one field of the LLM's JSON answer as stripped plain text."""
    if value is None:
        # JSON ``null`` means "nothing to report", not the literal "None".
        return ""
    if isinstance(value, str):
        return value.strip()
    if isinstance(value, (dict, list)):
        # Structured values are kept as JSON rather than as a Python repr.
        return json.dumps(value, ensure_ascii=False, indent=indent).strip()
    return str(value).strip()


def _analyze_correction_with_llm(llm: ChatOpenAI, une_auto: str, une_hitl: str) -> Tuple[str, str]:
    """Ask the LLM to describe one HITL correction and derive a general rule.

    Args:
        llm: Chat model; it is called through ``llm.invoke([...])``.
        une_auto: Automatically generated audio description.
        une_hitl: Human-corrected (HITL) version of the same description.

    Returns:
        ``(few_shot_example, rule)``:

        - Both strings are empty when the LLM call itself fails.
        - When the answer is not a JSON object, the raw answer is returned
          as the few-shot example and the rule is empty.
        - Otherwise the values come from the ``few_shot_example`` and
          ``rule`` keys of the JSON object; a missing or ``null`` key yields
          an empty string.
    """
    system = SystemMessage(
        content=(
            "Ets un assistent que analitza correccions d'audiodescripcions UNE-153010. "
            "Se't dona una versió automàtica i una versió corregida per humans (HITL). "
            "La teva tasca és (1) descriure de forma concisa què s'ha corregit, amb "
            "exemples concrets, i (2) proposar una regla general aplicable a futurs SRT. "
            "Respon en format JSON amb les claus 'few_shot_example' i 'rule'."
        )
    )
    payload = {
        "une_ad_auto": une_auto,
        "une_ad_hitl": une_hitl,
    }
    msg = HumanMessage(content=json.dumps(payload, ensure_ascii=False))

    try:
        resp = llm.invoke([system, msg])
    except Exception as exc:
        logger.error("Error llamando al LLM en introspection training: %s", exc)
        return "", ""

    raw = resp.content if isinstance(resp.content, str) else str(resp.content)
    text = _strip_markdown_fences(raw)

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        data = None

    # Valid JSON that is not an object (a list, a bare string, ...) has no
    # keys to read, so it is handled exactly like unparseable output.
    if not isinstance(data, dict):
        logger.warning("La respuesta del LLM no es JSON válido: %s", raw[:2000])
        return raw.strip(), ""

    few_shot = _llm_field_to_text(data.get("few_shot_example"), indent=2)
    rule = _llm_field_to_text(data.get("rule"))
    return few_shot, rule
|
|
|
|
|
def train_introspection_rules(max_examples: Optional[int] = None) -> None:
    """Learn introspection rules from the stored HITL corrections.

    For every (automatic, HITL) pair produced by ``_iter_une_vs_hitl_pairs``
    the LLM is asked for a "few_shot_example" describing the correction and
    a generalised "rule". Examples are appended to ``FEW_SHOT_PATH``, each
    preceded by a ``# sha1sum=...`` header line. Rules that are not already
    present are appended to ``RULES_PATH``, one per line.

    Args:
        max_examples: Upper bound on the number of pairs sent to the LLM;
            ``None`` (the default) processes every pair.

    Returns:
        None. Nothing is done when no LLM is available.
    """
    llm = _get_llm()
    if llm is None:
        logger.info("Introspection training skipped: no LLM available.")
        return

    logger.info("Començant entrenament d'introspection a partir de %s", AUDIODESCRIPTIONS_DB_PATH)

    # Rules already on disk, used to avoid appending duplicates.
    seen_rules = set()
    if RULES_PATH.exists():
        try:
            stored = RULES_PATH.read_text(encoding="utf-8")
        except Exception as exc:
            # Still best-effort, but no longer silent: without the stored
            # rules this run cannot detect duplicates.
            logger.warning("Could not read existing rules from %s: %s", RULES_PATH, exc)
        else:
            seen_rules = {line.strip() for line in stored.splitlines() if line.strip()}

    n_processed = 0
    n_generated = 0

    with FEW_SHOT_PATH.open("a", encoding="utf-8") as f_examples, RULES_PATH.open(
        "a", encoding="utf-8"
    ) as f_rules:
        for sha1sum, une_auto, une_hitl in _iter_une_vs_hitl_pairs():
            if max_examples is not None and n_processed >= max_examples:
                break

            n_processed += 1
            logger.info("Analitzant correcció HITL per sha1sum=%s", sha1sum)

            few_shot, rule = _analyze_correction_with_llm(llm, une_auto, une_hitl)

            # The rules file is read back line by line, so a rule has to
            # occupy exactly one line or it could never be matched again.
            rule = " ".join(rule.split())

            wrote_entry = False
            if few_shot:
                # Formatting (rather than concatenating) keeps a non-string
                # key from raising TypeError.
                f_examples.write(f"# sha1sum={sha1sum}\n{few_shot}\n\n")
                wrote_entry = True

            if rule and rule not in seen_rules:
                seen_rules.add(rule)
                f_rules.write(rule + "\n")
                wrote_entry = True

            # Count only pairs that actually added something to a file.
            if wrote_entry:
                n_generated += 1

    logger.info(
        "Introspection training completat: %d parelles processades, %d entrades generades",
        n_processed,
        n_generated,
    )
|
|
|
|
|
def _load_text_file(path: Path) -> str:
    """Return the UTF-8 text stored at ``path``.

    An empty string is returned when the file does not exist or when
    reading or decoding it fails.
    """
    text = ""
    if path.exists():
        try:
            text = path.read_text(encoding="utf-8")
        except Exception:
            # Best-effort read: an unreadable or undecodable file is treated
            # exactly like a missing one.
            text = ""
    return text
|
|
|
|
|
def refine_srt_with_introspection(srt_content: str) -> str:
    """Apply the introspection step to an SRT.

    The rules in ``RULES_PATH`` and the few-shot examples in
    ``FEW_SHOT_PATH`` are placed in the system prompt, and GPT-4o-mini is
    asked to return a corrected SRT.

    Args:
        srt_content: Automatically generated SRT text.

    Returns:
        The LLM's SRT with any surrounding Markdown fence removed.
        ``srt_content`` is returned unchanged when it is blank, when no LLM
        is available, when neither rules nor examples exist, when the LLM
        call fails, or when the LLM answers with nothing.
    """
    # Nothing to refine: skip the LLM call instead of asking it to improve
    # an empty document.
    if not srt_content or not srt_content.strip():
        return srt_content

    llm = _get_llm()
    if llm is None:
        return srt_content

    few_shots = _load_text_file(FEW_SHOT_PATH)
    rules = _load_text_file(RULES_PATH)
    if not few_shots and not rules:
        return srt_content

    system_parts = [
        "Ets un assistent que millora audiodescripcions en format SRT.",
        "Tens unes regles d'introspecció derivades de correccions humanes (HITL)",
        "i alguns exemples de correccions anteriors (few-shot examples).",
        "Has de produir un nou SRT que apliqui aquestes regles i millores,",
        "mantenint l'estructura de temps i el format SRT.",
        "Retorna únicament el SRT corregit, sense explicacions addicionals.",
    ]
    if rules:
        system_parts.append("\nRegles d'introspecció (una per línia):\n" + rules)
    if few_shots:
        system_parts.append("\nExemples de correccions (few-shot examples):\n" + few_shots)

    system_msg = SystemMessage(content="\n".join(system_parts))
    user_msg = HumanMessage(
        content=(
            "A continuació tens un SRT generat automàticament. "
            "Aplica les regles i l'estil observat als exemples per millorar-lo, "
            "especialment en aquells aspectes que solen ser corregits pels humans.\n\n"
            "SRT original:\n" + srt_content
        )
    )

    try:
        resp = llm.invoke([system_msg, user_msg])
    except Exception as exc:
        logger.error("Error llamando al LLM en introspection apply: %s", exc)
        return srt_content

    raw = resp.content if isinstance(resp.content, str) else str(resp.content)
    # Chat models often wrap their answer in a Markdown fence; fence lines
    # are not valid SRT, so they are removed before the result is returned.
    refined = _strip_markdown_fences(raw)
    return refined or srt_content
|
|
|