Spaces:
Runtime error
Runtime error
| import asyncio | |
| import base64 | |
| import os | |
| import re | |
| from dataclasses import asdict, dataclass | |
| from math import ceil | |
| import jsonlines | |
| import requests | |
| import tiktoken | |
| import yaml | |
| from FlagEmbedding import BGEM3FlagModel | |
| from jinja2 import Environment, Template | |
| from oaib import Auto | |
| from openai import OpenAI | |
| from PIL import Image | |
| from torch import Tensor, cosine_similarity | |
| from src.model_utils import get_text_embedding | |
| from src.utils import get_json_from_response, pexists, pjoin, print, tenacity | |
# Shared tokenizer used for all token accounting in this module (gpt-4o vocabulary).
ENCODING = tiktoken.encoding_for_model("gpt-4o")
def run_async(coroutine):
    """
    Run an asynchronous coroutine from synchronous (non-async) code.

    Args:
        coroutine: The coroutine to run.

    Returns:
        The result of the coroutine.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current event loop in this thread: create and install one.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    if loop.is_closed():
        # A previously-installed loop may have been closed (e.g. by a prior
        # asyncio.run); run_until_complete on it would raise RuntimeError.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coroutine)
def calc_image_tokens(images: list[str]):
    """
    Calculate the number of tokens for a list of images.

    Each image is billed 85 base tokens plus 170 tokens per 512x512 tile,
    after scaling the longer side down to at most 1024 px.
    """
    tokens = 0
    for path in images:
        with open(path, "rb") as stream:
            width, height = Image.open(stream).size
        if max(width, height) > 1024:
            # Shrink so the longer side becomes 1024, preserving aspect ratio.
            if width > height:
                height = int(height * 1024 / width)
                width = 1024
            else:
                width = int(width * 1024 / height)
                height = 1024
        tokens += 85 + 170 * ceil(height / 512) * ceil(width / 512)
    return tokens
| class LLM: | |
| """ | |
| A wrapper class to interact with a language model. | |
| """ | |
| def __init__( | |
| self, | |
| model: str = "gpt-4o-2024-08-06", | |
| api_base: str = None, | |
| use_openai: bool = True, | |
| use_batch: bool = False, | |
| ) -> None: | |
| """ | |
| Initialize the LLM. | |
| Args: | |
| model (str): The model name. | |
| api_base (str): The base URL for the API. | |
| use_openai (bool): Whether to use OpenAI. | |
| use_batch (bool): Whether to use OpenAI's Batch API, which is single thread only. | |
| """ | |
| if use_openai and "OPENAI_API_KEY" in os.environ: | |
| self.client = OpenAI(base_url=api_base) | |
| if use_batch and "OPENAI_API_KEY" in os.environ: | |
| assert use_openai, "use_batch must be used with use_openai" | |
| self.oai_batch = Auto(loglevel=0) | |
| if "OPENAI_API_KEY" not in os.environ: | |
| print("Warning: no API key found") | |
| self.model = model | |
| self.api_base = api_base | |
| self._use_openai = use_openai | |
| self._use_batch = use_batch | |
| def __call__( | |
| self, | |
| content: str, | |
| images: list[str] = None, | |
| system_message: str = None, | |
| history: list = None, | |
| delay_batch: bool = False, | |
| return_json: bool = False, | |
| return_message: bool = False, | |
| ) -> str | dict | list: | |
| """ | |
| Call the language model with a prompt and optional images. | |
| Args: | |
| content (str): The prompt content. | |
| images (list[str]): A list of image file paths. | |
| system_message (str): The system message. | |
| history (list): The conversation history. | |
| delay_batch (bool): Whether to delay return of response. | |
| return_json (bool): Whether to return the response as JSON. | |
| return_message (bool): Whether to return the message. | |
| Returns: | |
| str | dict | list: The response from the model. | |
| """ | |
| if content.startswith("You are"): | |
| system_message, content = content.split("\n", 1) | |
| if history is None: | |
| history = [] | |
| if isinstance(images, str): | |
| images = [images] | |
| system, message = self.format_message(content, images, system_message) | |
| if self._use_batch: | |
| result = run_async(self._run_batch(system + history + message, delay_batch)) | |
| if delay_batch: | |
| return | |
| try: | |
| response = result.to_dict()["result"][0]["choices"][0]["message"][ | |
| "content" | |
| ] | |
| except Exception as e: | |
| print("Failed to get response from batch") | |
| raise e | |
| elif self._use_openai: | |
| completion = self.client.chat.completions.create( | |
| model=self.model, messages=system + history + message | |
| ) | |
| response = completion.choices[0].message.content | |
| else: | |
| response = requests.post( | |
| self.api_base, | |
| json={ | |
| "system": system_message, | |
| "prompt": content, | |
| "image": [ | |
| i["image_url"]["url"] | |
| for i in message[-1]["content"] | |
| if i["type"] == "image_url" | |
| ], | |
| }, | |
| ) | |
| response.raise_for_status() | |
| response = response.text | |
| message.append({"role": "assistant", "content": response}) | |
| if return_json: | |
| response = get_json_from_response(response) | |
| if return_message: | |
| response = (response, message) | |
| return response | |
| def __repr__(self) -> str: | |
| return f"LLM(model={self.model}, api_base={self.api_base})" | |
| async def _run_batch(self, messages: list, delay_batch: bool = False): | |
| await self.oai_batch.add( | |
| "chat.completions.create", | |
| model=self.model, | |
| messages=messages, | |
| ) | |
| if delay_batch: | |
| return | |
| return await self.oai_batch.run() | |
| def format_message( | |
| self, | |
| content: str, | |
| images: list[str] = None, | |
| system_message: str = None, | |
| ): | |
| """ | |
| Message formatter for OpenAI server call. | |
| """ | |
| if system_message is None: | |
| system_message = "You are a helpful assistant" | |
| system = [ | |
| { | |
| "role": "system", | |
| "content": [{"type": "text", "text": system_message}], | |
| } | |
| ] | |
| message = [{"role": "user", "content": [{"type": "text", "text": content}]}] | |
| if images is not None: | |
| if not isinstance(images, list): | |
| images = [images] | |
| for image in images: | |
| with open(image, "rb") as f: | |
| message[0]["content"].append( | |
| { | |
| "type": "image_url", | |
| "image_url": { | |
| "url": f"data:image/jpeg;base64,{base64.b64encode(f.read()).decode('utf-8')}" | |
| }, | |
| } | |
| ) | |
| return system, message | |
| def get_batch_result(self): | |
| """ | |
| Get responses from delayed batch calls. | |
| """ | |
| results = run_async(self.oai_batch.run()) | |
| return [ | |
| r["choices"][0]["message"]["content"] | |
| for r in results.to_dict()["result"].values() | |
| ] | |
| def clear_history(self): | |
| self.history = [] | |
@dataclass
class Turn:
    """
    A class to represent a turn in a conversation.

    Fix: the @dataclass decorator was missing, yet the class is built with
    keyword field arguments and to_dict() relies on dataclasses.asdict —
    both fail on a plain class with only annotations.
    """

    id: int  # position of this turn in the owning role's history
    prompt: str
    response: str
    message: list  # raw request/response messages for replaying as context
    images: list[str] = None
    input_tokens: int = 0
    output_tokens: int = 0
    embedding: Tensor = None  # prompt embedding; set only for similarity lookup

    def to_dict(self):
        """Serialize the turn, dropping the non-JSON-serializable embedding."""
        return {k: v for k, v in asdict(self).items() if k != "embedding"}

    def calc_token(self):
        """
        Calculate the number of tokens for the turn.
        """
        if self.images is not None:
            self.input_tokens += calc_image_tokens(self.images)
        self.input_tokens += len(ENCODING.encode(self.prompt))
        self.output_tokens = len(ENCODING.encode(self.response))

    def __eq__(self, other):
        # Identity comparison on purpose: turns are deduplicated by object
        # in history retrieval, and the tensor field makes field-wise
        # equality unreliable. (dataclass leaves a user-defined __eq__ alone.)
        return self is other
class Role:
    """
    An agent, defined by its instruction template and model.
    """

    def __init__(
        self,
        name: str,
        env: Environment,
        record_cost: bool,
        llm: LLM = None,
        config: dict = None,
        text_model: BGEM3FlagModel = None,
    ):
        """
        Initialize the Agent.

        Args:
            name (str): The name of the role; also selects `roles/<name>.yaml`.
            env (Environment): The Jinja2 environment.
            record_cost (bool): Whether to record the token cost.
            llm (LLM): The language model; when None, resolved from the
                config's `use_model` key to the module-level
                `<use_model>_model` instance.
            config (dict): The configuration; loaded from the role YAML
                file when None.
            text_model (BGEM3FlagModel): The embedding model used for
                similarity-based history retrieval.
        """
        self.name = name
        if config is None:
            with open(f"roles/{name}.yaml", "r") as f:
                config = yaml.safe_load(f)
        if llm is None:
            # e.g. use_model: "language" -> module-level `language_model`.
            llm = globals()[config["use_model"] + "_model"]
        self.llm = llm
        self.model = llm.model
        self.record_cost = record_cost
        self.text_model = text_model
        self.return_json = config["return_json"]
        self.system_message = config["system_prompt"]
        self.prompt_args = set(config["jinja_args"])
        self.template = env.from_string(config["template"])
        self.retry_template = Template(
            """The previous output is invalid, please carefully analyze the traceback and feedback information, correct errors happened before.
feedback:
{{feedback}}
traceback:
{{traceback}}
Give your corrected output in the same format without including the previous output:
"""
        )
        self.system_tokens = len(ENCODING.encode(self.system_message))
        self.input_tokens = 0
        self.output_tokens = 0
        self.history: list[Turn] = []

    def calc_cost(self, turns: list[Turn]):
        """
        Accumulate the token cost of the given turns onto this role.
        """
        for turn in turns:
            self.input_tokens += turn.input_tokens
            self.output_tokens += turn.output_tokens
        # Every request also pays for the system prompt and reply priming.
        self.input_tokens += self.system_tokens
        self.output_tokens += 3

    def get_history(self, similar: int, recent: int, prompt: str):
        """
        Select turns for context: the `recent` latest turns, topped up with
        up to `similar` additional turns most similar to `prompt`.
        """
        history = self.history[-recent:] if recent > 0 else []
        if similar > 0:
            embedding = get_text_embedding(prompt, self.text_model)
            # Fix: rank the FULL history by similarity (descending) and add
            # turns not already selected. The previous code sorted only the
            # recent slice and iterated it with `if turn not in history`,
            # which could never be true — similarity retrieval was a no-op.
            # NOTE(review): assumes stored turns carry an embedding (only
            # turns created with similar > 0 do) — verify call sites.
            candidates = sorted(
                self.history,
                key=lambda t: cosine_similarity(embedding, t.embedding),
                reverse=True,
            )
            for turn in candidates:
                if len(history) >= similar + recent:
                    break
                if turn not in history:
                    history.append(turn)
        # Restore chronological order for the prompt.
        history.sort(key=lambda x: x.id)
        return history

    def save_history(self, output_dir: str):
        """
        Save the conversation history to `<output_dir>/<name>.jsonl`.

        The first record holds aggregate token counts; one record per turn
        follows. An existing file is preserved when there is nothing new.
        """
        history_file = pjoin(output_dir, f"{self.name}.jsonl")
        if pexists(history_file) and len(self.history) == 0:
            return
        with jsonlines.open(history_file, "w") as writer:
            writer.write(
                {
                    "input_tokens": self.input_tokens,
                    "output_tokens": self.output_tokens,
                }
            )
            for turn in self.history:
                writer.write(turn.to_dict())

    def retry(self, feedback: str, traceback: str, error_idx: int):
        """
        Retry a failed turn with feedback and traceback.

        Args:
            feedback (str): Feedback describing what was wrong.
            traceback (str): The traceback of the failure.
            error_idx (int): How many trailing turns to replay as context.
        """
        assert error_idx > 0, "error_idx must be greater than 0"
        prompt = self.retry_template.render(feedback=feedback, traceback=traceback)
        history = []
        for turn in self.history[-error_idx:]:
            history.extend(turn.message)
        response, message = self.llm(
            prompt,
            history=history,
            return_message=True,
        )
        turn = Turn(
            id=len(self.history),
            prompt=prompt,
            response=response,
            message=message,
        )
        return self.__post_process__(response, self.history[-error_idx:], turn)

    def __repr__(self) -> str:
        return f"Role(name={self.name}, model={self.model})"

    def __call__(
        self,
        images: list[str] = None,
        recent: int = 0,
        similar: int = 0,
        **jinja_args,
    ):
        """
        Call the agent with prompt arguments.

        Args:
            images (list[str]): A list of image file paths.
            recent (int): The number of recent turns to include.
            similar (int): The number of similar turns to include.
            **jinja_args: Arguments for the Jinja2 template; must match the
                config's `jinja_args` exactly.

        Returns:
            The response from the role (parsed as JSON when configured).
        """
        if isinstance(images, str):
            images = [images]
        assert self.prompt_args == set(jinja_args.keys()), "Invalid arguments"
        prompt = self.template.render(**jinja_args)
        history = self.get_history(similar, recent, prompt)
        history_msg = []
        for turn in history:
            history_msg.extend(turn.message)
        response, message = self.llm(
            prompt,
            system_message=self.system_message,
            history=history_msg,
            images=images,
            return_message=True,
        )
        turn = Turn(
            id=len(self.history),
            prompt=prompt,
            response=response,
            message=message,
            images=images,
        )
        return self.__post_process__(response, history, turn, similar)

    def __post_process__(
        self, response: str, history: list[Turn], turn: Turn, similar: int = 0
    ):
        """
        Record the turn, account token cost, and optionally parse JSON.
        """
        self.history.append(turn)
        if similar > 0:
            # Embed the prompt so this turn can later be retrieved by similarity.
            turn.embedding = get_text_embedding(turn.prompt, self.text_model)
        if self.record_cost:
            turn.calc_token()
            self.calc_cost(history + [turn])
        if self.return_json:
            response = get_json_from_response(response)
        return response
def get_simple_modelname(llms: list[LLM]):
    """
    Get an abbreviation from a list of LLMs.

    Args:
        llms (list[LLM] | LLM): One LLM or a list of LLMs.

    Returns:
        str: "+"-joined short names; each model name is truncated before its
        first "-NN" (two-digit) suffix, e.g. "gpt-4o-2024-08-06" -> "gpt-4o".
    """
    if isinstance(llms, LLM):
        llms = [llms]
    names = []
    for llm in llms:
        match = re.search(r"^(.*?)-\d{2}", llm.model)
        # Fix: fall back to the full model name when there is no "-NN"
        # suffix; the unguarded .group(1) raised AttributeError on None.
        names.append(match.group(1) if match else llm.model)
    return "+".join(names)
# Module-level model instances shared by roles (resolved via `use_model` + "_model").
gpt4o = LLM(model="gpt-4o-2024-08-06", use_batch=True)
gpt4omini = LLM(model="gpt-4o-mini-2024-07-18", use_batch=True)
# OpenAI-compatible self-hosted endpoints — presumably internal servers;
# NOTE(review): verify these hosts are reachable in the target deployment.
qwen2_5 = LLM(
    model="Qwen2.5-72B-Instruct-GPTQ-Int4", api_base="http://124.16.138.143:7812/v1"
)
qwen_vl = LLM(model="Qwen2-VL-72B-Instruct", api_base="http://124.16.138.144:7999/v1")
qwen_coder = LLM(
    model="Qwen2.5-Coder-32B-Instruct", api_base="http://127.0.0.1:8008/v1"
)
intern_vl = LLM(model="InternVL2_5-78B", api_base="http://124.16.138.144:8009/v1")
# Defaults used for language and vision roles.
language_model = gpt4o
vision_model = gpt4o
if __name__ == "__main__":
    # Manual smoke test: requires OPENAI_API_KEY; replaces the module-level
    # batch-mode instance with a direct (non-batch) client.
    gpt4o = LLM(model="gpt-4o-2024-08-06")
    answer = gpt4o(
        "who r u",
    )
    print(answer)