File size: 10,481 Bytes
31d4d14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""M貌dul per a l'agent d'"introspection".



Implementa:



- Un proc茅s d'entrenament que apr猫n de les correccions HITL comparant

  `une_ad` autom脿tic (MoE/Salamandra) amb `une_ad` de la versi贸 HITL.

- Un pas d'introspecci贸 que aplica aquestes regles a un nou SRT utilitzant

  GPT-4o-mini.

"""

from __future__ import annotations

import json
import logging
import os
import sqlite3
from pathlib import Path
from typing import Iterable, List, Optional, Tuple

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage


logger = logging.getLogger(__name__)


# --- Rutes i constants ---

BASE_DIR = Path(__file__).resolve().parent
# Estructura esperada: .../hf_spaces/engine/refinement/introspection.py
# Per tant, la "root" del repo 茅s el pare immediat de "engine".
REPO_ROOT = BASE_DIR.parents[1]
DEMO_DIR = REPO_ROOT / "demo"
DEMO_TEMP_DIR = DEMO_DIR / "temp"

REFINEMENT_TEMP_DIR = BASE_DIR / "temp"
REFINEMENT_TEMP_DIR.mkdir(exist_ok=True, parents=True)

FEW_SHOT_PATH = REFINEMENT_TEMP_DIR / "few_shot_examples.txt"
RULES_PATH = REFINEMENT_TEMP_DIR / "rules.txt"

AUDIODESCRIPTIONS_DB_PATH = DEMO_TEMP_DIR / "audiodescriptions.db"


def _get_llm() -> Optional[ChatOpenAI]:
    """Retorna una inst脿ncia de GPT-4o-mini o None si no hi ha API key."""

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        logger.warning("OPENAI_API_KEY no est谩 configurada; se omite la introspection.")
        return None
    try:
        return ChatOpenAI(model="gpt-4o-mini", temperature=0.0, api_key=api_key)
    except Exception as exc:  # pragma: no cover - errors de client extern
        logger.error("No se pudo inicializar ChatOpenAI para introspection: %s", exc)
        return None


# --- Lectura de dades d'entrenament ---

def _iter_une_vs_hitl_pairs() -> Iterable[Tuple[str, str, str]]:
    """Itera sobre (sha1sum, une_ad_auto, une_ad_hitl).



    A partir d'ara:

    - une_ad_auto: versi贸 autom脿tica (MoE o Salamandra), camp ``une_ad``.

    - une_ad_hitl: versi贸 corregida HITL guardada al mateix registre, camp ``ok_une_ad``.

    """

    if not AUDIODESCRIPTIONS_DB_PATH.exists():
        logger.warning("audiodescriptions.db no encontrado en %s", AUDIODESCRIPTIONS_DB_PATH)
        return

    conn = sqlite3.connect(str(AUDIODESCRIPTIONS_DB_PATH))
    conn.row_factory = sqlite3.Row
    try:
        cur = conn.cursor()
        try:
            cur.execute(
                """

                SELECT sha1sum, version, une_ad, ok_une_ad

                FROM audiodescriptions

                WHERE version IN ('MoE', 'Salamandra')

                """
            )
        except sqlite3.OperationalError:
            logger.warning("Tabla audiodescriptions no disponible en %s", AUDIODESCRIPTIONS_DB_PATH)
            return

        rows = cur.fetchall()
        for row in rows:
            sha1sum = row["sha1sum"]
            une_auto = (row["une_ad"] or "").strip()
            une_hitl = (row["ok_une_ad"] or "").strip() if "ok_une_ad" in row.keys() else ""

            if not une_auto or not une_hitl:
                continue

            if une_hitl == une_auto:
                # No hi ha difer猫ncies; no aporta informaci贸
                continue

            yield sha1sum, une_auto, une_hitl
    finally:
        conn.close()


def _strip_markdown_fences(content: str) -> str:
    """Elimina fences ```...``` alrededor de una respuesta JSON si existen."""

    text = content.strip()
    if text.startswith("```"):
        lines = text.splitlines()
        # descartar primera l铆nea con ``` o ```json
        lines = lines[1:]
        # eliminar el cierre ``` (pueden existir varias l铆neas en blanco finales)
        while lines and lines[-1].strip().startswith("```"):
            lines.pop()
        text = "\n".join(lines).strip()
    return text


def _analyze_correction_with_llm(llm: ChatOpenAI, une_auto: str, une_hitl: str) -> Tuple[str, str]:
    """Demana al LLM que descrigui la correcci贸 i extregui una regla general.



    Retorna (few_shot_example, rule). Si falla, retorna cadenes buides.

    """

    system = SystemMessage(
        content=(
            "Ets un assistent que analitza correccions d'audiodescripcions UNE-153010. "
            "Se't dona una versi贸 autom脿tica i una versi贸 corregida per humans (HITL). "
            "La teva tasca 茅s (1) descriure de forma concisa qu猫 s'ha corregit, amb "
            "exemples concrets, i (2) proposar una regla general aplicable a futurs SRT. "
            "Respon en format JSON amb les claus 'few_shot_example' i 'rule'."
        )
    )

    user_content = {
        "une_ad_auto": une_auto,
        "une_ad_hitl": une_hitl,
    }

    msg = HumanMessage(content=json.dumps(user_content, ensure_ascii=False))

    try:
        resp = llm.invoke([system, msg])
    except Exception as exc:  # pragma: no cover - errors externs
        logger.error("Error llamando al LLM en introspection training: %s", exc)
        return "", ""

    raw = resp.content if isinstance(resp.content, str) else str(resp.content)
    text = _strip_markdown_fences(raw)
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        logger.warning("La respuesta del LLM no es JSON v谩lido: %s", raw[:2000])
        return raw.strip(), ""

    few = data.get("few_shot_example", "")
    # Aceptamos tanto string como objeto; si es objeto, lo "bonificamos" a texto legible
    if isinstance(few, dict):
        try:
            few_shot = json.dumps(few, ensure_ascii=False, indent=2)
        except Exception:
            few_shot = str(few)
    else:
        few_shot = str(few)

    rule = str(data.get("rule", "")).strip()
    return few_shot.strip(), rule


def train_introspection_rules(max_examples: Optional[int] = None) -> None:
    """Entrena regles d'introspecci贸 a partir de les correccions HITL.



    - Recorre audiodescriptions.db buscant parelles (MoE/Salamandra, HITL).

    - Per a cada parella amb difer猫ncies significatives, demana al LLM:

        * Un "few_shot_example" que descrigui la correcci贸.

        * Una "rule" generalitzada.

    - Afegeix els exemples a ``few_shot_examples.txt`` i les regles 煤niques a

      ``rules.txt`` dins de ``engine/refinement/temp``.

    """

    llm = _get_llm()
    if llm is None:
        logger.info("Introspection training skipped: no LLM available.")
        return

    logger.info("Comen莽ant entrenament d'introspection a partir de %s", AUDIODESCRIPTIONS_DB_PATH)

    # Carregar regles existents per no duplicar-les
    existing_rules: List[str] = []
    if RULES_PATH.exists():
        try:
            existing_rules = [line.strip() for line in RULES_PATH.read_text(encoding="utf-8").splitlines() if line.strip()]
        except Exception:
            existing_rules = []

    seen_rules = set(existing_rules)

    n_processed = 0
    n_generated = 0

    with FEW_SHOT_PATH.open("a", encoding="utf-8") as f_examples, RULES_PATH.open(
        "a", encoding="utf-8"
    ) as f_rules:
        for sha1sum, une_auto, une_hitl in _iter_une_vs_hitl_pairs():
            if max_examples is not None and n_processed >= max_examples:
                break

            n_processed += 1
            logger.info("Analitzant correcci贸 HITL per sha1sum=%s", sha1sum)

            few_shot, rule = _analyze_correction_with_llm(llm, une_auto, une_hitl)
            if not few_shot and not rule:
                continue

            if few_shot:
                f_examples.write("# sha1sum=" + sha1sum + "\n")
                f_examples.write(few_shot + "\n\n")

            if rule and rule not in seen_rules:
                seen_rules.add(rule)
                f_rules.write(rule + "\n")

            n_generated += 1

    logger.info(
        "Introspection training completat: %d parelles processades, %d entrades generades",
        n_processed,
        n_generated,
    )


def _load_text_file(path: Path) -> str:
    if not path.exists():
        return ""
    try:
        return path.read_text(encoding="utf-8")
    except Exception:
        return ""


def refine_srt_with_introspection(srt_content: str) -> str:
    """Aplica el pas d'introspecci贸 sobre un SRT.



    - Llegeix ``few_shot_examples.txt`` i ``rules.txt`` de ``engine/refinement/temp``.

    - Demana a GPT-4o-mini que corregeixi el SRT tenint en compte aquests

      exemples i regles.

    - Si no hi ha LLM o fitxers, retorna el SRT original.

    """

    llm = _get_llm()
    if llm is None:
        return srt_content

    few_shots = _load_text_file(FEW_SHOT_PATH)
    rules = _load_text_file(RULES_PATH)

    if not few_shots and not rules:
        # Res a aplicar; no modifiquem el SRT
        return srt_content

    system_parts: List[str] = [
        "Ets un assistent que millora audiodescripcions en format SRT.",
        "Tens unes regles d'introspecci贸 derivades de correccions humanes (HITL)",
        "i alguns exemples de correccions anteriors (few-shot examples).",
        "Has de produir un nou SRT que apliqui aquestes regles i millores,",
        "mantenint l'estructura de temps i el format SRT.",
        "Retorna 煤nicament el SRT corregit, sense explicacions addicionals.",
    ]

    if rules:
        system_parts.append("\nRegles d'introspecci贸 (una per l铆nia):\n" + rules)

    if few_shots:
        system_parts.append("\nExemples de correccions (few-shot examples):\n" + few_shots)

    system_msg = SystemMessage(content="\n".join(system_parts))

    user_msg = HumanMessage(
        content=(
            "A continuaci贸 tens un SRT generat autom脿ticament. "
            "Aplica les regles i l'estil observat als exemples per millorar-lo, "
            "especialment en aquells aspectes que solen ser corregits pels humans.\n\n"
            "SRT original:\n" + srt_content
        )
    )

    try:
        resp = llm.invoke([system_msg, user_msg])
    except Exception as exc:  # pragma: no cover - errors externs
        logger.error("Error llamando al LLM en introspection apply: %s", exc)
        return srt_content

    text = resp.content if isinstance(resp.content, str) else str(resp.content)
    return text.strip() or srt_content