File size: 2,440 Bytes
956120a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Post-process LLM answers: citations, formatting, light quality heuristics."""

from __future__ import annotations

import re
from typing import Any, Dict, List, Sequence

# Explicit citations such as "[Doc 12]" or "[doc abc-3]" (case-insensitive).
# Group 1 is everything between "Doc" + whitespace and the closing bracket.
_DOC_CITATION_RE = re.compile(r"\[Doc\s+([^\]]+)\]", re.IGNORECASE)
# Any "[...]" span; group 1 is the text inside the brackets.
_BRACKET_RE = re.compile(r"\[([^\]]+)\]")


class ResponseProcessor:
    """Citation extraction, whitespace normalisation and simple quality scoring.

    All methods are static and treat a ``None`` text argument as an empty string.
    """

    @staticmethod
    def extract_citations(response: str, documents: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Find citations in *response* and map them to known document ids.

        Two passes are made over the text:

        1. Explicit ``[Doc <id>]`` markers (case-insensitive). Each distinct id
           is reported once, with ``resolved`` telling whether it matches the
           ``"id"`` of one of *documents*.
        2. Bare ``[<id>]`` markers, reported only when ``<id>`` is a known
           document id that the first pass has not already reported.

        Args:
            response: Answer text; ``None`` or empty yields no citations.
            documents: Mappings whose ``"id"`` value (compared via ``str()``)
                identifies a document. Entries without an ``"id"`` (or with
                ``None``) are ignored; ``None`` is treated as no documents.

        Returns:
            A list of ``{"id", "resolved", "span"}`` dicts: all explicit
            citations in order of first appearance, followed by bare ones.
        """
        text = response or ""
        # Skip documents that carry no id: str(None) would otherwise register
        # the literal "None" as a known id and make "[Doc None]" look resolved.
        known_ids = {
            str(doc["id"]) for doc in (documents or ()) if doc.get("id") is not None
        }
        found: List[Dict[str, Any]] = []
        seen: set[str] = set()

        for match in _DOC_CITATION_RE.finditer(text):
            cited_id = match.group(1).strip()
            if cited_id in seen:
                continue
            seen.add(cited_id)
            found.append(
                {
                    "id": cited_id,
                    "resolved": cited_id in known_ids,
                    "span": match.group(0),
                }
            )

        # Secondary: bare bracket ids that match a document id.
        for match in _BRACKET_RE.finditer(text):
            inner = match.group(1).strip()
            # "[Doc ...]" spans were already handled by the first pass.
            if inner.lower().startswith("doc "):
                continue
            if inner in known_ids and inner not in seen:
                seen.add(inner)
                found.append({"id": inner, "resolved": True, "span": match.group(0)})
        return found

    @staticmethod
    def format_response(response: str) -> str:
        """Return *response* stripped, with CRLF as LF and blank-line runs collapsed.

        Runs of three or more consecutive newlines are reduced to exactly two
        (i.e. at most one blank line). ``None`` yields an empty string.
        """
        text = (response or "").strip()
        text = text.replace("\r\n", "\n")
        return re.sub(r"\n{3,}", "\n\n", text)

    @staticmethod
    def _terms(text: str) -> set:
        """Return the lowercase ``[a-z0-9]+`` tokens of *text* longer than two characters."""
        return {word for word in re.findall(r"[a-z0-9]+", text.lower()) if len(word) > 2}

    @staticmethod
    def score_quality(response: str, query: str) -> float:
        """Heuristic score in 0..1 from length, citation count and query overlap.

        Args:
            response: Answer text. After stripping, fewer than 20 characters
                scores ``0.0`` outright.
            query: The user query; ``None`` is treated as an empty query.

        Returns:
            ``0.25`` base, plus up to ``0.45`` for the fraction of query terms
            that also occur in the response, up to ``0.35`` for ``[Doc ...]``
            citations (``0.07`` each) and up to ``0.35`` for length (reached at
            280 characters), clamped to ``[0.0, 1.0]``.
        """
        text = (response or "").strip()
        if len(text) < 20:
            return 0.0

        query_terms = ResponseProcessor._terms(query or "")
        if query_terms:
            response_terms = ResponseProcessor._terms(text)
            overlap = len(query_terms & response_terms) / len(query_terms)
        else:
            # No usable query terms: fall back to a fixed overlap value
            # instead of scoring the overlap as zero.
            overlap = 0.3

        cite_bonus = min(0.35, 0.07 * len(_DOC_CITATION_RE.findall(text)))
        length_score = min(0.35, len(text) / 800.0)
        return max(0.0, min(1.0, 0.25 + overlap * 0.45 + cite_bonus + length_score))