# src/agents.py - last updated in commit 70d4442 ("Update src/agents.py") by AlsuGibadullina.
import time
from dataclasses import dataclass
from typing import Dict, Any, Optional, Callable
from .backends import LLMBackend, make_backend
from .config import ModelSpec
from .tasks import TaskContext
@dataclass()
class AgentResult:
    """Record of a single agent invocation.

    Captures who ran (role, model, backend), what was sent (prompt and
    system prompt), what came back (output), how long the call took, and
    which generation parameters were used.
    """

    # Identity of the agent / model that produced this result.
    role: str
    model_id: str
    backend: str
    # Inputs sent to the backend; ``system`` may be None.
    prompt: str
    system: Optional[str]
    # Output text returned by the backend.
    output: str
    # Duration of the call, in seconds.
    elapsed_s: float
    # Generation parameters passed to the backend for this call.
    params: Dict[str, Any]
class BaseAgent:
    """Base class for an agent that wraps one LLM backend described by a ModelSpec.

    Subclasses only override the ``role`` class attribute; generation is
    implemented once here in :meth:`run`.
    """

    # Label copied into every AgentResult this agent produces.
    role: str = "agent"

    def __init__(self, spec: ModelSpec):
        """Store *spec* and build the backend it names.

        Args:
            spec: Model configuration. ``spec.backend`` and ``spec.model_id``
                are passed to ``make_backend`` to construct ``self.backend``.
        """
        self.spec = spec
        self.backend: LLMBackend = make_backend(spec.backend, spec.model_id)

    def run(self, prompt: str) -> AgentResult:
        """Send *prompt* to the backend and return the output with metadata.

        Generation parameters come from the spec (temperature,
        max_new_tokens, top_p, repetition_penalty). Any entries in
        ``spec.extra`` are merged afterwards, so they override defaults of
        the same name.

        Args:
            prompt: User prompt forwarded to ``self.backend.generate``.

        Returns:
            An AgentResult with this agent's role, the spec's model id and
            backend, the prompt and system prompt, the backend output, the
            elapsed time in seconds, and the params dict that was passed to
            the backend.
        """
        # perf_counter is monotonic: unlike time.time(), the measured duration
        # cannot jump or go negative if the system clock is adjusted mid-call.
        t0 = time.perf_counter()
        params: Dict[str, Any] = {
            "temperature": self.spec.temperature,
            "max_new_tokens": self.spec.max_new_tokens,
            "top_p": self.spec.top_p,
            "repetition_penalty": self.spec.repetition_penalty,
            # Merged last on purpose: extra keys override the defaults above.
            **(self.spec.extra or {}),
        }
        out = self.backend.generate(prompt, system=self.spec.system_prompt, params=params)
        elapsed_s = time.perf_counter() - t0
        return AgentResult(
            role=self.role,
            model_id=self.spec.model_id,
            backend=self.spec.backend,
            prompt=prompt,
            system=self.spec.system_prompt,
            output=out,
            elapsed_s=elapsed_s,
            params=params,
        )
class AnalyzerAgent(BaseAgent):
    """Agent labelled with the ``analyzer`` role; all behaviour is inherited from BaseAgent."""

    role: str = "analyzer"
class RefactorAgent(BaseAgent):
    """Agent labelled with the ``refactor`` role; all behaviour is inherited from BaseAgent."""

    role: str = "refactor"
class CriticAgent(BaseAgent):
    """Agent labelled with the ``critic`` role; all behaviour is inherited from BaseAgent."""

    role: str = "critic"