"""Base agent: shared config, logging, file access, and an Ollama-backed LLM."""

import json
from typing import AsyncGenerator, Dict

from langchain_core.callbacks import StreamingStdOutCallbackHandler
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import OllamaLLM

from config import Config
from file_manager import FileManager
from logger import Logger
from models import AgentState

class BaseAgent:
    """Common base class for LLM agents served by Ollama."""

    def __init__(self):
        self.config = Config()
        self.logger = Logger()
        self.file_manager = FileManager()

        # Echo tokens to stdout only when streaming is enabled in the config.
        callbacks = [StreamingStdOutCallbackHandler()] if self.config.streaming else None

        self.llm = OllamaLLM(
            model=self.config.ollama_model,
            base_url=self.config.ollama_base_url,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            callbacks=callbacks,
        )

    async def _stream_process(self, state: AgentState, prompt_template: str,
                              output_key: str, step_name: str, **kwargs) -> AsyncGenerator[str, None]:
        """Stream the chain's output, yielding each chunk as a JSON string."""
        prompt = ChatPromptTemplate.from_template(prompt_template)
        chain = prompt | self.llm | StrOutputParser()

        async for chunk in chain.astream(kwargs):
            yield json.dumps({output_key: chunk})

    async def _process(self, state: AgentState, prompt_template: str,
                       output_key: str, step_name: str, **kwargs) -> Dict:
        """Invoke the chain once and return the full result under output_key."""
        prompt = ChatPromptTemplate.from_template(prompt_template)
        chain = prompt | self.llm

        result = await chain.ainvoke(kwargs)
        return {output_key: result}
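

# --- Usage sketch -----------------------------------------------------------
# Illustrative only: "ExampleAgent", its prompt, and the `topic` field on
# AgentState are assumptions, not part of the original module. It shows the
# intended pattern: subclass BaseAgent, then delegate to _process for a single
# result or _stream_process for incremental JSON chunks.
class ExampleAgent(BaseAgent):
    PROMPT = "Write a short overview of {topic}."

    async def run(self, state: AgentState) -> Dict:
        # One-shot call: returns {"overview": <full LLM output>}.
        return await self._process(
            state,
            prompt_template=self.PROMPT,
            output_key="overview",
            step_name="overview",
            topic=state.topic,  # assumes AgentState exposes a `topic` field
        )

    async def run_streaming(self, state: AgentState) -> AsyncGenerator[str, None]:
        # Streaming call: yields '{"overview": "<chunk>"}' strings as they arrive.
        async for chunk in self._stream_process(
            state,
            prompt_template=self.PROMPT,
            output_key="overview",
            step_name="overview",
            topic=state.topic,
        ):
            yield chunk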