File size: 1,489 Bytes
f25adba 70d4442 f25adba 70d4442 f25adba 70d4442 f25adba 70d4442 f25adba 70d4442 f25adba 70d4442 f25adba 70d4442 f25adba 70d4442 f25adba 70d4442 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 | import time
from dataclasses import dataclass
from typing import Dict, Any, Optional, Callable
from .backends import LLMBackend, make_backend
from .config import ModelSpec
from .tasks import TaskContext
@dataclass(init=True, repr=True, eq=True)
class AgentResult:
    """Record of one agent invocation: inputs, output, and timing.

    Fields are positional in the order declared below.
    """

    role: str  # role label of the agent that produced this result
    model_id: str  # model identifier taken from the agent's spec
    backend: str  # backend name taken from the agent's spec
    prompt: str  # prompt text sent to the backend
    system: Optional[str]  # system prompt sent alongside, or None
    output: str  # value returned by the backend's generate call
    elapsed_s: float  # seconds measured around the generation call
    params: Dict[str, Any]  # generation parameters passed to the backend
class BaseAgent:
    """Run prompts through an LLM backend configured by a ``ModelSpec``.

    The generation logic lives here. The subclasses in this module differ
    only in the ``role`` label stamped onto each result.
    """

    # Label copied into every AgentResult produced by this agent.
    role: str = "agent"

    def __init__(self, spec: ModelSpec):
        """Store *spec* and build the backend it names.

        Args:
            spec: Model configuration. ``spec.backend`` and ``spec.model_id``
                are passed to ``make_backend`` to construct the backend.
        """
        self.spec = spec
        self.backend: LLMBackend = make_backend(spec.backend, spec.model_id)

    def run(self, prompt: str) -> AgentResult:
        """Generate a completion for *prompt* and return it with metadata.

        Sampling parameters are read from the spec. Keys in ``spec.extra``
        are merged last, so they override the standard ones.

        Args:
            prompt: User prompt forwarded to the backend.

        Returns:
            An AgentResult holding the prompt, the system prompt, the
            backend output, the parameter dict passed to the backend, and
            the elapsed seconds for building the params and generating.
        """
        # perf_counter is monotonic and high-resolution. Differences of
        # time.time() can be skewed (even negative) by system clock changes.
        t0 = time.perf_counter()
        spec = self.spec
        params = {
            "temperature": spec.temperature,
            "max_new_tokens": spec.max_new_tokens,
            "top_p": spec.top_p,
            "repetition_penalty": spec.repetition_penalty,
            # Merged last on purpose: extra keys override the values above.
            **(spec.extra or {}),
        }
        out = self.backend.generate(prompt, system=spec.system_prompt, params=params)
        return AgentResult(
            role=self.role,
            model_id=spec.model_id,
            backend=spec.backend,
            prompt=prompt,
            system=spec.system_prompt,
            output=out,
            elapsed_s=time.perf_counter() - t0,
            params=params,
        )
class AnalyzerAgent(BaseAgent):
    """BaseAgent whose results are tagged with the ``"analyzer"`` role."""

    role: str = "analyzer"
class RefactorAgent(BaseAgent):
    """BaseAgent whose results are tagged with the ``"refactor"`` role."""

    role: str = "refactor"
class CriticAgent(BaseAgent):
    """BaseAgent whose results are tagged with the ``"critic"`` role."""

    role: str = "critic"