File size: 1,489 Bytes
f25adba
 
70d4442
f25adba
 
 
70d4442
f25adba
 
 
 
 
 
70d4442
f25adba
 
 
 
 
 
 
 
 
 
 
 
70d4442
f25adba
70d4442
f25adba
 
 
 
 
70d4442
f25adba
 
70d4442
f25adba
 
 
70d4442
f25adba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70d4442
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import time
from dataclasses import dataclass
from typing import Dict, Any, Optional, Callable

from .backends import LLMBackend, make_backend
from .config import ModelSpec
from .tasks import TaskContext


@dataclass
class AgentResult:
    """Record of a single agent run: what was asked, with which settings, and the reply."""

    # Role label of the agent that produced this result (e.g. "agent").
    role: str
    # Identifier of the model taken from the agent's spec.
    model_id: str
    # Backend identifier taken from the agent's spec.
    backend: str
    # Prompt text that was sent to the backend.
    prompt: str
    # System prompt that accompanied the prompt, or None when there was none.
    system: Optional[str]
    # Value returned by the backend's generate call.
    output: str
    # Duration of the run in seconds, as measured by the agent.
    elapsed_s: float
    # Generation parameters that were passed to the backend for this run.
    params: Dict[str, Any]


class BaseAgent:
    """Agent that sends prompts to the single LLM backend named by its spec.

    ``role`` is a class-level label that is copied into every ``AgentResult``
    produced by :meth:`run`; subclasses may override it.
    """

    role: str = "agent"

    def __init__(self, spec: ModelSpec):
        """Keep *spec* and build the backend it names.

        Args:
            spec: Model configuration. ``spec.backend`` and ``spec.model_id``
                are handed to ``make_backend`` to create the backend.
        """
        self.spec = spec
        self.backend: LLMBackend = make_backend(spec.backend, spec.model_id)

    def run(self, prompt: str) -> AgentResult:
        """Generate a reply for *prompt* and return it with its run metadata.

        Args:
            prompt: Text passed to the backend as the prompt.

        Returns:
            An ``AgentResult`` carrying the backend output together with the
            system prompt, the generation parameters that were used, and the
            elapsed time in seconds.

        No exception handling is done here, so anything raised by the
        backend's ``generate`` propagates to the caller.
        """
        spec = self.spec
        # perf_counter() is monotonic: unlike the wall clock it cannot jump or
        # run backwards when the system time is adjusted, so the measured
        # duration is never negative or skewed by clock changes.
        started = time.perf_counter()
        params = {
            "temperature": spec.temperature,
            "max_new_tokens": spec.max_new_tokens,
            "top_p": spec.top_p,
            "repetition_penalty": spec.repetition_penalty,
            # Unpacked last so entries in spec.extra override the keys above.
            **(spec.extra or {}),
        }
        out = self.backend.generate(prompt, system=spec.system_prompt, params=params)
        elapsed_s = time.perf_counter() - started
        return AgentResult(
            role=self.role,
            model_id=spec.model_id,
            backend=spec.backend,
            prompt=prompt,
            system=spec.system_prompt,
            output=out,
            elapsed_s=elapsed_s,
            params=params,
        )


class AnalyzerAgent(BaseAgent):
    """BaseAgent variant whose role label is "analyzer"; it overrides nothing else."""

    role = "analyzer"


class RefactorAgent(BaseAgent):
    """BaseAgent variant whose role label is "refactor"; it overrides nothing else."""

    role = "refactor"


class CriticAgent(BaseAgent):
    """BaseAgent variant whose role label is "critic"; it overrides nothing else."""

    role = "critic"