"""LangGraph utilities"""

import os
from typing import Dict, Any, Optional, List
from dotenv import load_dotenv
import json
import json_repair

from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic  
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.schema import HumanMessage, SystemMessage
from langchain_community.callbacks.manager import get_openai_callback
from tenacity import retry, stop_after_attempt, wait_exponential

from src.state.poster_state import ModelConfig

load_dotenv()


def create_model(config: ModelConfig):
    """Build a LangChain chat model instance for the provider in *config*.

    Supported providers: 'openai', 'anthropic', 'google'. The API key and
    an optional base URL are read from environment variables.

    Raises:
        ValueError: if ``config.provider`` is not a supported provider name.
    """
    # provider -> (model class, constructor kwargs, env var for optional base URL)
    registry = {
        'openai': (
            ChatOpenAI,
            {
                'model_name': config.model_name,
                'temperature': config.temperature,
                'max_tokens': config.max_tokens,
                'api_key': os.getenv('OPENAI_API_KEY'),
            },
            'OPENAI_BASE_URL',
        ),
        'anthropic': (
            ChatAnthropic,
            {
                'model': config.model_name,
                'temperature': config.temperature,
                'max_tokens': config.max_tokens,
                'api_key': os.getenv('ANTHROPIC_API_KEY'),
            },
            'ANTHROPIC_BASE_URL',
        ),
        'google': (
            ChatGoogleGenerativeAI,
            {
                'model': config.model_name,
                'temperature': config.temperature,
                # Google's SDK names the output cap differently.
                'max_output_tokens': config.max_tokens,
                'google_api_key': os.getenv('GOOGLE_API_KEY'),
            },
            'GOOGLE_BASE_URL',
        ),
    }

    if config.provider not in registry:
        raise ValueError(f"unsupported provider: {config.provider}")

    model_cls, kwargs, base_url_var = registry[config.provider]
    base_url = os.getenv(base_url_var)
    if base_url:
        kwargs['base_url'] = base_url
    return model_cls(**kwargs)


class LangGraphAgent:
    """Wrapper around a LangChain chat model that keeps a rolling
    conversation history and reports token usage per call."""

    def __init__(self, system_msg: str, config: ModelConfig):
        """Create the agent with a system prompt and a model config.

        Args:
            system_msg: system prompt that anchors every conversation.
            config: provider/model settings passed to ``create_model``.
        """
        self.system_msg = system_msg
        self.config = config
        self.model = create_model(config)
        self.history = [SystemMessage(content=system_msg)]

    def reset(self):
        """Drop all conversation history except the system message."""
        self.history = [SystemMessage(content=self.system_msg)]

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    def step(self, message: str) -> 'AgentResponse':
        """Send *message* to the model and return the reply with token counts.

        If *message* is a JSON-encoded list containing ``image_url`` items,
        the call is routed to the vision path (history is not updated there).
        Failures re-raise so the ``@retry`` decorator can attempt again.
        """
        # Detect a multimodal payload: a JSON list with image_url entries.
        # Narrow except: only "this isn't such a payload" errors are expected
        # here (bad JSON, or `in` applied to a non-container item).
        try:
            msg_data = json.loads(message)
            if isinstance(msg_data, list) and any("image_url" in item for item in msg_data):
                return self._step_vision(msg_data)
        except (json.JSONDecodeError, TypeError):
            pass  # plain-text message: fall through to the regular path

        self.history.append(HumanMessage(content=message))

        # Rolling window: keep the system message plus the 9 most recent.
        if len(self.history) > 10:
            self.history = [self.history[0]] + self.history[-9:]

        input_tokens, output_tokens = 0, 0
        try:
            if self.config.provider == 'openai':
                # OpenAI exposes exact usage through the callback manager.
                with get_openai_callback() as cb:
                    response = self.model.invoke(self.history)
                    input_tokens = cb.prompt_tokens or 0
                    output_tokens = cb.completion_tokens or 0
            else:
                response = self.model.invoke(self.history)
                # No usage callback for other providers: rough word-count
                # estimate (~1.3 tokens per word), coerced to int to match
                # AgentResponse's declared types.
                input_tokens = int(len(message.split()) * 1.3)
                output_tokens = int(len(response.content.split()) * 1.3)
        except Exception as e:
            # Log and re-raise so @retry can back off and retry the call.
            print(f"model call failed: {e}")
            raise

        self.history.append(response)

        return AgentResponse(response.content, input_tokens, output_tokens)

    def _step_vision(self, messages: List[Dict]) -> 'AgentResponse':
        """Handle a multimodal (text + image) call.

        Only the system message and the new multimodal message are sent;
        the rolling history is intentionally not consulted or updated.
        """
        # Normalize incoming items into LangChain's multimodal content format.
        content = []
        for msg in messages:
            if msg.get("type") == "text":
                content.append({"type": "text", "text": msg["text"]})
            elif msg.get("type") == "image_url":
                content.append({
                    "type": "image_url",
                    "image_url": msg["image_url"]
                })

        human_msg = HumanMessage(content=content)

        input_tokens, output_tokens = 0, 0
        try:
            if self.config.provider == 'openai':
                with get_openai_callback() as cb:
                    response = self.model.invoke([self.history[0], human_msg])
                    input_tokens = cb.prompt_tokens or 0
                    output_tokens = cb.completion_tokens or 0
            else:
                response = self.model.invoke([self.history[0], human_msg])
                # Rough estimates: flat cost for the image, word count for text.
                input_tokens = 200
                output_tokens = int(len(response.content.split()) * 1.3)
        except Exception as e:
            # Log and re-raise so @retry (via step's caller) can handle it.
            print(f"vision model call failed: {e}")
            raise

        return AgentResponse(response.content, input_tokens, output_tokens)


class AgentResponse:
    """Container for a model reply plus its token accounting.

    Attributes:
        content: the model's text response.
        input_tokens: tokens consumed by the prompt (0 if unknown).
        output_tokens: tokens produced in the completion (0 if unknown).
    """

    def __init__(self, content: str, input_tokens: int, output_tokens: int):
        self.content = content
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens

    def __repr__(self) -> str:
        # Debug-friendly repr; the original class fell back to the opaque
        # default object repr.
        return (f"{type(self).__name__}(content={self.content!r}, "
                f"input_tokens={self.input_tokens}, "
                f"output_tokens={self.output_tokens})")


def extract_json(response: str) -> Dict[str, Any]:
    """Extract and parse a JSON payload from a model response.

    Prefers the content of a fenced ```json ... ``` block; if the opening
    fence has no closing fence, the fence marker itself is stripped so it
    is not fed to the parser (the original implementation passed it
    through verbatim in that case). Parsing goes through ``json_repair``,
    which tolerates mildly malformed JSON.

    Raises:
        ValueError: if the content cannot be parsed even after repair.
    """
    fence = "```json"
    start = response.find(fence)
    end = response.rfind("```")

    if start != -1 and end > start:
        # Well-formed fenced block: take only what's inside the fences.
        json_content = response[start + len(fence):end].strip()
    elif start != -1:
        # Opening fence without a closing one: drop the fence marker.
        json_content = response[start + len(fence):].strip()
    else:
        json_content = response.strip()

    try:
        return json_repair.loads(json_content)
    except Exception as e:
        # Chain the cause so the original parser error stays visible.
        raise ValueError(f"failed to parse json: {e}") from e


def load_prompt(path: str) -> str:
    """Read a prompt template file and return its full text (UTF-8)."""
    with open(path, mode='r', encoding='utf-8') as template_file:
        template_text = template_file.read()
    return template_text