File size: 7,198 Bytes
13bc746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
Triangle — The core engine.
Three models. One question. The disagreement is the data.
"""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional
from nova_triangle.result import TriangleResult


class Triangle:
    """
    Triangulated inference across three language models.

    Instead of asking one model and trusting the answer, we ask three.
    One proposes (steers). Two evaluate. If they converge, high confidence.
    If they diverge, the disagreement itself is useful data.

    The steering role rotates. No model is always the boss.

    Args:
        models: Exactly three model identifiers/paths for ``from_pretrained``.
        device: Torch device string; defaults to "cuda" when available, else "cpu".
        dtype: Weight dtype passed to ``from_pretrained``.
        max_tokens: Maximum number of new tokens generated per answer.
        max_rounds: Maximum number of answer rounds per prompt (must be >= 1).
        convergence_threshold: Mean pairwise similarity at or above which the
            three answers count as converged.
        flag_threshold: Confidence below which a non-converged result gets a
            manual-review flag.
        max_input_tokens: Token length at which prompts are truncated by the
            tokenizer before generation.

    Raises:
        ValueError: if ``models`` does not hold exactly three entries, or
            ``max_rounds`` is less than 1.
    """

    def __init__(
        self,
        models: List[str],
        device: Optional[str] = None,
        dtype: torch.dtype = torch.float16,
        max_tokens: int = 200,
        max_rounds: int = 3,
        convergence_threshold: float = 0.7,
        flag_threshold: float = 0.4,
        max_input_tokens: int = 512,
    ):
        if len(models) != 3:
            raise ValueError("Triangle requires exactly 3 models. That's the whole point.")
        if max_rounds < 1:
            # Zero rounds would leave process() with no answer to return.
            raise ValueError(f"max_rounds must be >= 1, got {max_rounds}.")

        # Copy so later mutation of the caller's list cannot desync names and models.
        self.model_names = list(models)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_tokens = max_tokens
        self.max_rounds = max_rounds
        self.convergence_threshold = convergence_threshold
        self.flag_threshold = flag_threshold
        self.max_input_tokens = max_input_tokens
        self._steer_index = 0

        self.models = []
        self.tokenizers = []

        for name in self.model_names:
            # trust_remote_code=True runs Python code shipped with the model
            # repository -- only pass model ids you trust.
            tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
            if tok.pad_token is None:
                # generate() needs a pad id; reuse EOS for tokenizers without one.
                tok.pad_token = tok.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                name, torch_dtype=dtype, trust_remote_code=True
            ).to(self.device)
            model.eval()
            self.tokenizers.append(tok)
            self.models.append(model)

    def _generate(self, model_idx: int, prompt: str) -> str:
        """Sample one answer from model ``model_idx``; return only the new text, stripped."""
        tok = self.tokenizers[model_idx]
        model = self.models[model_idx]
        inputs = tok(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_tokens,
        ).to(self.device)
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=self.max_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tok.pad_token_id,
            )
        # generate() returns prompt tokens + continuation; drop the prompt part.
        prompt_len = inputs["input_ids"].shape[1]
        response = tok.decode(out[0][prompt_len:], skip_special_tokens=True)
        return response.strip()

    def _similarity(self, a: str, b: str) -> float:
        """
        Lexical similarity between two responses: Jaccard overlap of their
        lower-cased, whitespace-split word sets (0.0 if either is empty).

        Not semantic, but fast and sufficient for convergence detection.
        LB can swap in embedding-based similarity when benchmarks are ready.
        """
        words_a = set(a.lower().split())
        words_b = set(b.lower().split())
        if not words_a or not words_b:
            return 0.0
        intersection = words_a & words_b
        union = words_a | words_b
        return len(intersection) / len(union)

    def _check_convergence(self, responses: List[str]) -> tuple:
        """
        Do the three responses agree?

        Returns:
            (converged, confidence, disagreement): ``confidence`` is the mean
            pairwise ``_similarity``; ``converged`` is
            ``confidence >= convergence_threshold``; ``disagreement`` maps the
            two models of the least-similar pair to their responses (empty
            when converged).
        """
        pairs = [(i, j) for i in range(3) for j in range(i + 1, 3)]
        sims = [self._similarity(responses[i], responses[j]) for i, j in pairs]

        avg_sim = sum(sims) / len(sims)
        converged = avg_sim >= self.convergence_threshold

        disagreement = {}
        if not converged:
            # Report the pair that disagreed most (lowest similarity, first on ties).
            i, j = pairs[sims.index(min(sims))]
            name_i, name_j = self.model_names[i], self.model_names[j]
            if name_i == name_j:
                # Same model in two slots: keep both answers instead of the
                # second silently overwriting the first.
                name_i, name_j = f"{name_i} [{i}]", f"{name_j} [{j}]"
            disagreement[name_i] = responses[i]
            disagreement[name_j] = responses[j]

        return converged, avg_sim, disagreement

    def _run_rounds(self, prompt: str, steer: int) -> tuple:
        """
        Run up to ``max_rounds`` answer rounds, stopping early on convergence.

        Returns:
            (responses, confidence, converged, disagreement, rounds): the first
            four come from the highest-confidence round (earliest on ties);
            ``rounds`` is how many rounds were actually run.
        """
        best_responses = None
        best_confidence = 0.0
        best_converged = False
        best_disagreement = {}
        round_num = 0

        for round_num in range(1, self.max_rounds + 1):
            if round_num == 1:
                # First round: all three answer independently
                query = prompt
            else:
                # Later rounds: show everyone the steering model's answer from
                # the best round so far and ask them to respond to it.
                query = (
                    f"{prompt}\n\n"
                    f"A previous analysis suggested: {best_responses[steer]}\n"
                    f"Do you agree, disagree, or have a different perspective?"
                )
            responses = [self._generate(i, query) for i in range(3)]

            converged, confidence, disagreement = self._check_convergence(responses)

            # Always keep the first round: a 0.0 confidence (empty answers or no
            # shared words) must not leave best_responses as None.
            if best_responses is None or confidence > best_confidence:
                best_responses = responses
                best_confidence = confidence
                best_converged = converged
                best_disagreement = disagreement

            if converged:
                break

        return best_responses, best_confidence, best_converged, best_disagreement, round_num

    def process(self, prompt: str) -> TriangleResult:
        """
        Run triangulated inference.

        One model steers (proposes). All three answer. Check convergence.
        If they disagree, the disagreement is returned — it's signal, not failure.
        """
        steer = self._steer_index
        # Rotate so the next call is steered by the next model.
        self._steer_index = (self._steer_index + 1) % 3

        responses, confidence, converged, disagreement, rounds = self._run_rounds(prompt, steer)

        # The answer is the steering model's response (it proposed, others validated)
        answer = responses[steer]

        # Generate flag if disagreement was significant
        flag = None
        if not converged and confidence < self.flag_threshold:
            flag = (
                f"High disagreement (confidence {confidence:.2f}). "
                f"The models found something worth examining manually."
            )

        return TriangleResult(
            answer=answer,
            confidence=confidence,
            converged=converged,
            disagreement=disagreement,
            flag=flag,
            raw_responses=responses,
            steering_model=self.model_names[steer],
            rounds=rounds,
        )

    def process_batch(self, prompts: List[str]) -> List[TriangleResult]:
        """Process prompts in order; the steering role rotates from one prompt to the next."""
        return [self.process(p) for p in prompts]

    def report(self, result: TriangleResult) -> str:
        """Human-readable summary of a triangle result."""

        def clip(text: str, limit: int) -> str:
            # Append "..." only when something was actually cut off.
            return text[:limit] + ("..." if len(text) > limit else "")

        lines = [
            f"Steered by: {result.steering_model}",
            f"Converged: {'Yes' if result.converged else 'No'} ({result.rounds} round{'s' if result.rounds > 1 else ''})",
            f"Confidence: {result.confidence:.1%}",
            f"Answer: {clip(result.answer, 200)}",
        ]
        if result.flag:
            lines.append(f"FLAG: {result.flag}")
        if result.disagreement:
            lines.append("Disagreement:")
            for model, resp in result.disagreement.items():
                lines.append(f"  {model}: {clip(resp, 100)}")
        return "\n".join(lines)