File size: 1,794 Bytes
5a2d62e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import json
from typing import Dict, Generator, AsyncGenerator
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.callbacks import StreamingStdOutCallbackHandler
from config import Config
from logger import Logger
from file_manager import FileManager
from models import AgentState
from langchain_ollama import OllamaLLM
from langchain_core.runnables import RunnableConfig
from langchain_core.output_parsers import StrOutputParser
class BaseAgent:
    """Common base for LLM-backed agents.

    Creates the project's ``Config``, ``Logger`` and ``FileManager`` plus an
    ``OllamaLLM`` client configured from ``Config``, and provides two helpers
    for running a single prompt template through that LLM: one that streams
    JSON-encoded partial chunks and one that awaits the complete result.
    """

    def __init__(self):
        self.config = Config()
        self.logger = Logger()
        self.file_manager = FileManager()
        # Echo tokens to stdout via a callback only when streaming is enabled
        # in the config; otherwise attach no callbacks at all.
        callbacks = [StreamingStdOutCallbackHandler()] if self.config.streaming else None
        self.llm = OllamaLLM(
            model=self.config.ollama_model,
            base_url=self.config.ollama_base_url,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            callbacks=callbacks,
        )

    def _build_chain(self, prompt_template: str):
        """Return the runnable ``prompt | llm | StrOutputParser()``.

        Shared by the streaming and non-streaming paths so that both always
        run the same pipeline and both produce plain ``str`` output.

        Args:
            prompt_template: Template text for ``ChatPromptTemplate.from_template``.
        """
        # NOTE(review): ChatPromptTemplate produces a chat prompt value; a
        # completion-style LLM such as OllamaLLM receives it rendered as a
        # single string with role prefixes (e.g. "Human: ..."). If that
        # prefix is unwanted, PromptTemplate may be the better fit — confirm.
        prompt = ChatPromptTemplate.from_template(prompt_template)
        return prompt | self.llm | StrOutputParser()

    async def _stream_process(self, state: AgentState, prompt_template: str, output_key: str, step_name: str, **kwargs) -> AsyncGenerator[str, None]:
        """Stream the model's output for *prompt_template* as JSON strings.

        Args:
            state: Current agent state. Not read by this implementation.
            prompt_template: Template text whose variables are filled from
                ``kwargs``.
            output_key: Key under which each chunk is wrapped.
            step_name: Name of the calling step. Not read by this
                implementation.
            **kwargs: Values for the template's input variables.

        Yields:
            ``json.dumps({output_key: chunk})`` for each partial text chunk
            produced by the chain's ``astream``.
        """
        chain = self._build_chain(prompt_template)
        async for chunk in chain.astream(kwargs):
            yield json.dumps({output_key: chunk})

    async def _process(self, state: AgentState, prompt_template: str,
                       output_key: str, step_name: str, **kwargs) -> Dict[str, str]:
        """Run *prompt_template* through the model and return the full output.

        Uses the same pipeline as ``_stream_process`` (including
        ``StrOutputParser``) so both paths return text of the same type.

        Args:
            state: Current agent state. Not read by this implementation.
            prompt_template: Template text whose variables are filled from
                ``kwargs``.
            output_key: Key under which the result is returned.
            step_name: Name of the calling step. Not read by this
                implementation.
            **kwargs: Values for the template's input variables.

        Returns:
            A one-entry dict ``{output_key: text}`` holding the complete
            model output.
        """
        chain = self._build_chain(prompt_template)
        result = await chain.ainvoke(kwargs)
        return {output_key: result}