| import base64 |
| import json |
| from openai import AzureOpenAI |
| import os |
| import sys |
| sys.path.append('./rxn/') |
| import torch |
| import json |
| from getReaction import get_reaction |
|
|
|
|
|
|
class RXNIM:
    """Agentic Azure OpenAI client that extracts chemical-reaction data
    from an image, letting the model invoke the local ``get_reaction``
    tool when it needs structured reaction parsing."""


    def __init__(self, api_version='2024-06-01', azure_endpoint='https://hkust.azure-api.net'):
        """Build the Azure OpenAI client and register the tool schema.

        Args:
            api_version: Azure OpenAI REST API version string.
            azure_endpoint: Base URL of the Azure OpenAI gateway.

        Raises:
            ValueError: If the ``key`` environment variable is unset.
        """
        self.API_KEY = os.environ.get('key')
        if not self.API_KEY:
            # The lookup above reads 'key'; the message must name that same
            # variable (it previously said 'KEY', misleading the operator).
            raise ValueError("Environment variable 'key' not set.")

        self.client = AzureOpenAI(
            api_key=self.API_KEY,
            api_version=api_version,
            azure_endpoint=azure_endpoint,
        )

        # Tool schema advertised to the model (OpenAI function-calling format).
        self.tools = [
            {
                'type': 'function',
                'function': {
                    'name': 'get_reaction',
                    'description': 'Get a list of reactions from a reaction image. A reaction contains data of the reactants, conditions, and products.',
                    'parameters': {
                        'type': 'object',
                        'properties': {
                            'image_path': {
                                'type': 'string',
                                'description': 'The path to the reaction image.',
                            },
                        },
                        'required': ['image_path'],
                        'additionalProperties': False,
                    },
                },
            },
        ]

        # Maps tool names returned in the model's tool_calls to the local
        # callables that implement them.
        self.TOOL_MAP = {
            'get_reaction': get_reaction,
        }

    def encode_image(self, image_path: str) -> str:
        '''Returns a base64 string of the input image.'''
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def process(self, image_path: str, prompt_path: str) -> str:
        """Run the agent loop on one reaction image.

        Sends the prompt plus the base64-encoded image to the model, executes
        any tool calls it requests, feeds the tool results back, and repeats
        until the model answers without tool calls or the iteration cap hits.

        Args:
            image_path: Path to the reaction image on disk.
            prompt_path: Path to a text file containing the user prompt.

        Returns:
            The assistant's final message content (a JSON string, given the
            json_object response format), or a failure notice when the
            iteration cap is exceeded.
        """
        base64_image = self.encode_image(image_path)

        with open(prompt_path, 'r') as prompt_file:
            prompt = prompt_file.read()

        messages = [
            {'role': 'system', 'content': 'You are a helpful assistant. Before providing the final answer, consider if any additional information or tool usage is needed to improve your response.'},
            {
                'role': 'user',
                'content': [
                    {
                        'type': 'text',
                        'text': prompt
                    },
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': f'data:image/png;base64,{base64_image}'
                        }
                    }
                ]
            },
        ]

        # Cap the tool-call loop so a model that keeps requesting tools
        # cannot spin forever.
        MAX_ITERATIONS = 5
        iterations = 0

        while iterations < MAX_ITERATIONS:
            iterations += 1
            print(f'Iteration {iterations}')

            response = self.client.chat.completions.create(
                model='gpt-4o',
                temperature=0,
                response_format={'type': 'json_object'},
                messages=messages,
                tools=self.tools,
            )

            assistant_message = response.choices[0].message
            # The SDK message object is appended directly; the client accepts
            # it alongside plain dict messages on the next request.
            messages.append(assistant_message)

            tool_calls = getattr(assistant_message, 'tool_calls', None)
            if not tool_calls:
                # No tool requested: the assistant's answer is final.
                break

            results = []
            for tool_call in tool_calls:
                tool_name = tool_call.function.name
                tool_call_id = tool_call.id

                if tool_name in self.TOOL_MAP:
                    try:
                        # Deliberately pass the caller's image_path instead of
                        # the model-supplied argument: the model only ever saw
                        # a base64 data URL, so any path it invents is
                        # unreliable. (This also removes an unguarded
                        # json.loads of the arguments payload that could crash
                        # the loop while its result went unused.)
                        tool_result = self.TOOL_MAP[tool_name](image_path)
                        print(f'{tool_name} result: {tool_result}')
                    except Exception as e:
                        # Surface tool failures to the model rather than
                        # aborting the whole run.
                        tool_result = {'error': str(e)}
                else:
                    tool_result = {'error': f"Unknown tool called: {tool_name}"}

                results.append({
                    'role': 'tool',
                    'content': json.dumps({
                        'image_path': image_path,
                        f'{tool_name}': tool_result,
                    }),
                    'tool_call_id': tool_call_id,
                })
            print(results)

            messages.extend(results)
        else:
            # while-else: loop exhausted without a break, i.e. the model kept
            # requesting tools for MAX_ITERATIONS rounds.
            return "The assistant could not complete the task within the maximum number of iterations."

        final_content = assistant_message.content
        return final_content
|
|
| |
|
|
|
|