import json
from typing import AsyncGenerator, Dict

from langchain_core.callbacks import StreamingStdOutCallbackHandler
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import OllamaLLM

from config import Config
from file_manager import FileManager
from logger import Logger
from models import AgentState

class BaseAgent:
    """Shared base class for agents: owns config, logging, file access, and the Ollama LLM."""

    def __init__(self):
        self.config = Config()
        self.logger = Logger()
        self.file_manager = FileManager()

        # Stream tokens to stdout only when streaming is enabled in the config.
        callbacks = [StreamingStdOutCallbackHandler()] if self.config.streaming else None

        self.llm = OllamaLLM(
            model=self.config.ollama_model,
            base_url=self.config.ollama_base_url,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            callbacks=callbacks,
        )
        
    async def _stream_process(self, state: AgentState, prompt_template: str,
                              output_key: str, step_name: str, **kwargs) -> AsyncGenerator[str, None]:
        """Stream the LLM response, yielding each partial chunk as a JSON string."""
        prompt = ChatPromptTemplate.from_template(prompt_template)
        chain = prompt | self.llm | StrOutputParser()

        # astream() yields partial output chunks as the model generates them;
        # kwargs supplies the prompt template variables.
        async for chunk in chain.astream(kwargs):
            yield json.dumps({output_key: chunk})

    async def _process(self, state: AgentState, prompt_template: str,
                       output_key: str, step_name: str, **kwargs) -> Dict:
        """Run the LLM once and return its full response under output_key."""
        prompt = ChatPromptTemplate.from_template(prompt_template)
        chain = prompt | self.llm

        result = await chain.ainvoke(kwargs)
        return {output_key: result}
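

# Illustrative usage sketch: a hypothetical subclass showing how _process and
# _stream_process are intended to be called. SummarizerAgent, SUMMARY_PROMPT,
# and the state["text"] access are assumptions for the example; the last one
# presumes AgentState is dict-like (e.g. a TypedDict), as is common in
# LangGraph-style agents.
class SummarizerAgent(BaseAgent):
    SUMMARY_PROMPT = "Summarize the following text:\n\n{text}"

    async def run(self, state: AgentState) -> Dict:
        # One-shot call: returns {"summary": "<full model output>"}.
        return await self._process(
            state, self.SUMMARY_PROMPT,
            output_key="summary", step_name="summarize",
            text=state["text"],
        )

    async def run_streaming(self, state: AgentState) -> AsyncGenerator[str, None]:
        # Streaming call: yields JSON strings like '{"summary": "<chunk>"}'.
        async for chunk in self._stream_process(
            state, self.SUMMARY_PROMPT,
            output_key="summary", step_name="summarize",
            text=state["text"],
        ):
            yield chunk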