File size: 7,786 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import asyncio
import requests
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion
from typing import Optional, List
from litellm import token_counter
from ..core.registry import register_model
from .model_configs import OpenRouterConfig
from .base_model import BaseLLM
from .model_utils import Cost, cost_manager
@register_model(config_cls=OpenRouterConfig, alias=["openrouter"])
class OpenRouterLLM(BaseLLM):
    """Chat-completion backend for the OpenRouter API, driven through the
    OpenAI-compatible client.

    Supports synchronous, asynchronous, streaming, and batch generation, and
    tracks per-request token costs using OpenRouter's public pricing endpoint.
    """

    def init_model(self):
        """Create the API client and record which config fields are connection
        settings rather than completion parameters."""
        config: OpenRouterConfig = self.config
        self._client = self._init_client(config)
        # Connection / bookkeeping settings that must never be forwarded to the
        # completions endpoint as sampling parameters.
        self._default_ignore_fields = ["llm_type", "openrouter_key", "openrouter_base", "openrouter_model_base", "output_response"]
        # Lazily filled cache of (input_cost_per_token, output_cost_per_token);
        # avoids one HTTP round-trip to the pricing endpoint per generation.
        self._pricing_cache: Optional[tuple] = None

    def _init_client(self, config: OpenRouterConfig):
        """Build an OpenAI-compatible client pointed at the OpenRouter base URL."""
        client = OpenAI(
            api_key=config.openrouter_key,
            base_url=config.openrouter_base
        )
        return client

    def formulate_messages(self, prompts: List[str], system_messages: Optional[List[str]] = None) -> List[List[dict]]:
        """Convert raw prompt strings (plus optional per-prompt system messages)
        into OpenAI-style message lists.

        Args:
            prompts: One user prompt per conversation.
            system_messages: Optional list parallel to ``prompts``; a falsy
                entry produces a conversation without a system message.

        Returns:
            One ``[{"role": ..., "content": ...}, ...]`` list per prompt.

        Raises:
            AssertionError: If both lists are given but differ in length.
        """
        if system_messages:
            assert len(prompts) == len(system_messages), f"the number of prompts ({len(prompts)}) is different from the number of system_messages ({len(system_messages)})"
        else:
            system_messages = [None] * len(prompts)
        messages_list = []
        for prompt, system_message in zip(prompts, system_messages):
            messages = []
            if system_message:
                messages.append({"role": "system", "content": system_message})
            messages.append({"role": "user", "content": prompt})
            messages_list.append(messages)
        return messages_list

    def update_completion_params(self, params1: dict, params2: dict) -> dict:
        """Overlay ``params2`` onto ``params1`` in place, keeping only keys that
        are recognized completion parameters and not in the ignore list.

        Returns:
            ``params1`` (mutated) for call-chaining convenience.
        """
        config_params: list = self.config.get_config_params()
        for key, value in params2.items():
            if key in self._default_ignore_fields:
                continue
            if key not in config_params:
                continue
            params1[key] = value
        return params1

    def get_completion_params(self, **kwargs):
        """Collect completion parameters from the config, then apply caller
        overrides from ``kwargs``."""
        completion_params = self.config.get_set_params(ignore=self._default_ignore_fields)
        completion_params = self.update_completion_params(completion_params, kwargs)
        return completion_params

    def get_stream_output(self, response: Stream, output_response: bool=True) -> str:
        """Drain a synchronous streaming response, optionally echoing chunks to
        stdout, and return the concatenated text."""
        output = ""
        for chunk in response:
            content = chunk.choices[0].delta.content
            if content:
                if output_response:
                    print(content, end="", flush=True)
                output += content
        if output_response:
            print("")
        return output

    async def get_stream_output_async(self, response, output_response: bool = False) -> str:
        """Async counterpart of :meth:`get_stream_output` for ``__aiter__``-able streams."""
        output = ""
        async for chunk in response:
            content = chunk.choices[0].delta.content
            if content:
                if output_response:
                    print(content, end="", flush=True)
                output += content
        if output_response:
            print("")
        return output

    def get_completion_output(self, response: ChatCompletion, output_response: bool=True) -> str:
        """Extract the text of the first choice from a non-streaming completion,
        optionally printing it."""
        output = response.choices[0].message.content
        if output_response:
            print(output)
        return output

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
    def single_generate(self, messages: List[dict], **kwargs) -> str:
        """Generate one completion synchronously, record its cost, and return the text.

        Retried up to 5 times with exponential backoff by the ``tenacity`` wrapper.

        Raises:
            RuntimeError: On any failure, with the original exception chained as
                the cause.
        """
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)
        try:
            completion_params = self.get_completion_params(**kwargs)
            response = self._client.chat.completions.create(messages=messages, **completion_params)
            if stream:
                output = self.get_stream_output(response, output_response=output_response)
                # Streaming responses carry no usage block, so count tokens locally.
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response)
            self._update_cost(cost=cost)
        except Exception as e:
            # Chain the cause so the root traceback survives the retry wrapper.
            raise RuntimeError(f"Error during single_generate of OpenRouterLLM: {str(e)}") from e
        return output

    def batch_generate(self, batch_messages: List[List[dict]], **kwargs) -> List[str]:
        """Generate sequentially for each conversation in ``batch_messages``."""
        return [self.single_generate(messages=one_messages, **kwargs) for one_messages in batch_messages]

    async def single_generate_async(self, messages: List[dict], **kwargs) -> str:
        """Async variant of :meth:`single_generate`.

        The blocking client call is pushed onto the default executor; a fresh
        client is created per call so concurrent tasks do not share connection
        state.

        Raises:
            RuntimeError: On any failure, with the original exception chained.
        """
        stream = kwargs.get("stream", self.config.stream)
        output_response = kwargs.get("output_response", self.config.output_response)
        try:
            isolated_client = self._init_client(self.config)
            completion_params = self.get_completion_params(**kwargs)
            # get_running_loop() is the correct call inside a coroutine;
            # get_event_loop() is deprecated here.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: isolated_client.chat.completions.create(
                    messages=messages,
                    **completion_params
                )
            )
            if stream:
                if hasattr(response, "__aiter__"):
                    output = await self.get_stream_output_async(response, output_response=output_response)
                else:
                    output = self.get_stream_output(response, output_response=output_response)
                cost = self._stream_cost(messages=messages, output=output)
            else:
                output: str = self.get_completion_output(response=response, output_response=output_response)
                cost = self._completion_cost(response)
            self._update_cost(cost=cost)
        except Exception as e:
            # Chain the cause so the root traceback survives.
            raise RuntimeError(f"Error during single_generate_async of OpenRouterLLM: {str(e)}") from e
        return output

    def _completion_cost(self, response: ChatCompletion) -> Cost:
        """Cost of a non-streaming completion, from the API-reported usage counts."""
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        return self._compute_cost(input_tokens=input_tokens, output_tokens=output_tokens)

    def _stream_cost(self, messages: List[dict], output: str) -> Cost:
        """Cost of a streamed completion, with token counts estimated locally
        via litellm's tokenizer (streams report no usage)."""
        model: str = self.config.model
        input_tokens = token_counter(model=model, messages=messages)
        output_tokens = token_counter(model=model, text=output)
        return self._compute_cost(input_tokens=input_tokens, output_tokens=output_tokens)

    def _compute_cost(self, input_tokens: int, output_tokens: int) -> Cost:
        """Combine token counts with per-token prices into a :class:`Cost` record."""
        input_cost_per_token, output_cost_per_token = self._get_cost()
        input_cost = input_tokens * input_cost_per_token
        output_cost = output_tokens * output_cost_per_token
        cost = Cost(input_tokens=input_tokens, output_tokens=output_tokens, input_cost=input_cost, output_cost=output_cost)
        return cost

    def _update_cost(self, cost: Cost):
        """Record ``cost`` against this model in the global cost manager."""
        cost_manager.update_cost(cost=cost, model=self.config.model)

    def _get_cost(self):
        """Return ``(input_cost_per_token, output_cost_per_token)`` for the
        configured model from OpenRouter's public model listing.

        The result is cached on the instance, so the pricing endpoint is hit at
        most once per model instance instead of once per generation. Returns
        ``(0, 0)`` when the model is absent from the listing.
        """
        # getattr() keeps this safe even if init_model() has not run yet.
        cached = getattr(self, "_pricing_cache", None)
        if cached is not None:
            return cached
        url = self.config.openrouter_model_base
        # Bounded timeout keeps a slow pricing endpoint from hanging generation.
        response = requests.get(url, timeout=30)
        data = response.json()
        pricing_pair = (0, 0)
        for model in data['data']:
            if model['id'] == self.config.model:
                pricing = model.get('pricing', {})
                input_cost = float(pricing.get('prompt', 0))
                output_cost = float(pricing.get('completion', 0))
                pricing_pair = (input_cost, output_cost)
                break
        self._pricing_cache = pricing_pair
        return pricing_pair