"""LangGraph utilities"""
import os
from typing import Dict, Any, Optional, List
from dotenv import load_dotenv
import json
import json_repair
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.schema import HumanMessage, SystemMessage
from langchain_community.callbacks.manager import get_openai_callback
from tenacity import retry, stop_after_attempt, wait_exponential
from src.state.poster_state import ModelConfig
load_dotenv()
def create_model(config: ModelConfig):
    """Instantiate a LangChain chat model for the provider named in *config*.

    Supported providers are 'openai', 'anthropic', and 'google'. API keys
    and optional endpoint overrides are read from environment variables
    (<PROVIDER>_API_KEY / <PROVIDER>_BASE_URL).

    Raises:
        ValueError: if config.provider is not a supported provider name.
    """
    if config.provider == 'openai':
        kwargs = dict(
            model_name=config.model_name,
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            api_key=os.getenv('OPENAI_API_KEY'),
        )
        endpoint = os.getenv('OPENAI_BASE_URL')
        if endpoint:
            kwargs['base_url'] = endpoint
        return ChatOpenAI(**kwargs)

    if config.provider == 'anthropic':
        kwargs = dict(
            model=config.model_name,
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            api_key=os.getenv('ANTHROPIC_API_KEY'),
        )
        endpoint = os.getenv('ANTHROPIC_BASE_URL')
        if endpoint:
            kwargs['base_url'] = endpoint
        return ChatAnthropic(**kwargs)

    if config.provider == 'google':
        kwargs = dict(
            model=config.model_name,
            temperature=config.temperature,
            max_output_tokens=config.max_tokens,
            google_api_key=os.getenv('GOOGLE_API_KEY'),
        )
        endpoint = os.getenv('GOOGLE_BASE_URL')
        if endpoint:
            kwargs['base_url'] = endpoint
        return ChatGoogleGenerativeAI(**kwargs)

    raise ValueError(f"unsupported provider: {config.provider}")
class LangGraphAgent:
    """Conversational wrapper around a LangChain chat model.

    Maintains a rolling message history (system prompt plus the most recent
    turns) and reports per-call token usage via AgentResponse. Text messages
    are accumulated in the history; vision calls are one-shot and stateless.
    """

    # history cap: the system message plus the last 9 conversation messages
    _MAX_HISTORY = 10

    def __init__(self, system_msg: str, config: ModelConfig):
        """Create the agent from a system prompt and a model configuration."""
        self.system_msg = system_msg
        self.config = config
        self.model = create_model(config)
        self.history = [SystemMessage(content=system_msg)]

    def reset(self):
        """Discard the conversation, keeping only the system message."""
        self.history = [SystemMessage(content=self.system_msg)]

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
    def step(self, message: str) -> 'AgentResponse':
        """Send *message* to the model and return the reply with token counts.

        If *message* is a JSON-encoded list containing image_url parts it is
        routed to the vision path; otherwise it is appended to the running
        history as a plain text turn. Retries up to 3 times on model errors;
        tenacity re-raises the last exception if all attempts fail.
        """
        # Detect a multimodal payload: a JSON list with at least one dict
        # carrying an "image_url" key. The dict check matters: a list of
        # strings would pass a substring test but crash in _step_vision.
        try:
            msg_data = json.loads(message)
            if isinstance(msg_data, list) and any(
                isinstance(item, dict) and "image_url" in item for item in msg_data
            ):
                return self._step_vision(msg_data)
        except (TypeError, ValueError):
            # not JSON (json.JSONDecodeError is a ValueError) -> plain text
            pass

        # regular text call
        self.history.append(HumanMessage(content=message))
        # keep the conversation window bounded: system message + recent turns
        if len(self.history) > self._MAX_HISTORY:
            self.history = [self.history[0]] + self.history[-(self._MAX_HISTORY - 1):]

        input_tokens, output_tokens = 0, 0
        try:
            if self.config.provider == 'openai':
                # OpenAI exposes exact usage through the callback handler
                with get_openai_callback() as cb:
                    response = self.model.invoke(self.history)
                    input_tokens = cb.prompt_tokens or 0
                    output_tokens = cb.completion_tokens or 0
            else:
                response = self.model.invoke(self.history)
                # rough word-count estimate for providers without callbacks
                input_tokens = int(len(message.split()) * 1.3)
                output_tokens = int(len(response.content.split()) * 1.3)
        except Exception as e:
            # Drop the uncommitted user message so a tenacity retry does not
            # append the same message to the history a second time.
            self.history.pop()
            print(f"model call failed: {e}")
            raise
        self.history.append(response)
        return AgentResponse(response.content, input_tokens, output_tokens)

    def _step_vision(self, messages: List[Dict]) -> 'AgentResponse':
        """One-shot multimodal call: system message plus one mixed-content turn.

        Vision calls are stateless — neither the request nor the reply is
        recorded in self.history.
        """
        # normalize incoming parts into LangChain's content-list format;
        # parts with an unrecognized "type" are silently dropped
        content = []
        for msg in messages:
            part_type = msg.get("type")
            if part_type == "text":
                content.append({"type": "text", "text": msg["text"]})
            elif part_type == "image_url":
                content.append({
                    "type": "image_url",
                    "image_url": msg["image_url"]
                })
        human_msg = HumanMessage(content=content)

        input_tokens, output_tokens = 0, 0
        try:
            if self.config.provider == 'openai':
                with get_openai_callback() as cb:
                    response = self.model.invoke([self.history[0], human_msg])
                    input_tokens = cb.prompt_tokens or 0
                    output_tokens = cb.completion_tokens or 0
            else:
                response = self.model.invoke([self.history[0], human_msg])
                input_tokens = 200  # rough fixed estimate for an image payload
                output_tokens = int(len(response.content.split()) * 1.3)
        except Exception as e:
            print(f"vision model call failed: {e}")
            raise
        return AgentResponse(response.content, input_tokens, output_tokens)
class AgentResponse:
    """Value object pairing a model reply with its token usage."""

    def __init__(self, content: str, input_tokens: int, output_tokens: int):
        """Store the reply text and the input/output token counts."""
        self.content, self.input_tokens, self.output_tokens = (
            content,
            input_tokens,
            output_tokens,
        )
def extract_json(response: str) -> Dict[str, Any]:
    """Extract and parse a JSON payload from a model response.

    Prefers the content of a fenced code block — either ```json ... ``` or a
    plain ``` ... ``` fence without a language tag — and falls back to parsing
    the whole response when no complete fence is found. json_repair tolerates
    minor syntax errors (trailing commas, single quotes, ...).

    Raises:
        ValueError: if no JSON could be recovered from the response.
    """
    # locate the opening fence; prefer an explicit ```json tag, otherwise
    # accept a bare ``` fence so untagged code blocks are also unwrapped
    start = response.find("```json")
    offset = 7  # len("```json")
    if start == -1:
        start = response.find("```")
        offset = 3
    end = response.rfind("```")
    if start != -1 and end > start:
        json_content = response[start + offset:end].strip()
    else:
        # no complete fence: let json_repair chew on the raw response
        json_content = response.strip()
    try:
        return json_repair.loads(json_content)
    except Exception as e:
        raise ValueError(f"failed to parse json: {e}") from e
def load_prompt(path: str) -> str:
    """Return the full UTF-8 text of the prompt template stored at *path*."""
    with open(path, mode='r', encoding='utf-8') as template_file:
        return template_file.read()