Spaces:
Sleeping
Sleeping
| from openai import OpenAI | |
| import json_repair | |
| from transformers import AutoTokenizer | |
| from prompts import * | |
| import re | |
| from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_exception_type | |
| from openai import RateLimitError | |
| from difflib import get_close_matches | |
class ChatbotSimulation:
    """Simulate an app as a text conversation driven by an OpenAI chat model.

    State kept per instance:
      * ``conversation`` -- messages sent to the model (trimmed to a token budget).
      * ``trajectory``   -- everything typed by / shown to the user, for logging.
      * ``user_state``   -- the latest state dict returned by the model.
      * ``page_history`` -- stack of visited pages; the top is the page being shown.
      * ``actions``      -- human-readable record of each accepted action.
    """

    # Retry policy for OpenAI calls that fail with RateLimitError.
    _RETRY_ATTEMPTS = 5
    _RETRY_WAIT_SECONDS = 10
    # Extra model calls allowed for coercing a malformed state update into a dict.
    _MAX_STATE_REPAIR_ATTEMPTS = 3
    # Header that introduces the button list in a rendered page.
    _OPTIONS_HEADER_RE = re.compile(r"you have the following options:", re.IGNORECASE)
    # One button line: "<number>. <button name>: <action_type>".
    _BUTTON_RE = re.compile(r"(\d+)\.\s+(.*?):\s+([a-zA-Z_]+)")
    # A user action: "action_type(button number)" or "action_type(button number, query)".
    _ACTION_RE = re.compile(r"^(?P<action_type>\w+)\((?P<button_number>[^,]+)(?:,\s*(?P<query>.+))?\)$")

    def __init__(self, app_name, app_description, site_map, relevant_tables_per_page,
                 database, jinjia_prerender_page, task, solution,
                 log_location, openai_api_key, agent='human',
                 max_steps=30, max_tokens=8192, buffer_tokens=500, model='gpt-4'):
        """Set up the simulator.

        Args:
            app_name: Display name used in greetings and status messages.
            app_description: Passed to the system prompt.
            site_map: Dict with a ``'pages'`` list whose entries carry an ``'id'``;
                the first page is the starting page.
            relevant_tables_per_page: Maps page id -> table names in ``database``.
            database: Maps table name -> table data.
            jinjia_prerender_page: Maps page id -> pre-rendered page for the system prompt.
            task: Task description shown to the user and given to the model.
            solution: Passed to the state-update prompt.
            log_location: Stored; not read by this class's methods.
            openai_api_key: API key for the OpenAI client.
            agent: ``'human'`` or ``'llm'`` (case-insensitive).
            max_steps: Stored; not read by this class's methods.
            max_tokens: Token budget for ``conversation`` (see ``_trim_conversation``).
            buffer_tokens: Completion token limit per request; also reserved twice
                out of ``max_tokens`` when trimming.
            model: OpenAI chat model used for every request.

        Raises:
            ValueError: If ``agent`` is neither ``'human'`` nor ``'llm'``.
        """
        self.app_name = app_name
        self.app_description = app_description
        self.sitemap = site_map
        self.relevant_tables_per_page = relevant_tables_per_page
        self.database = database
        self.jinjia_prerender_page = jinjia_prerender_page
        self.task = task
        self.solution = solution
        self.user_state = {
            'current_page': self._first_page_id(),
            'task_completed': 'False',
            'logged_in': 'False',
            'back': 'False',
        }
        self.log_location = log_location
        self.agent = agent.lower()
        if self.agent not in ['human', 'llm']:
            raise ValueError("Invalid agent type. Expected 'Human' or 'llm'.")
        self.max_steps = max_steps
        self.max_tokens = max_tokens
        self.buffer_tokens = buffer_tokens
        self.model = model
        self.conversation = []  # Recent messages sent to the model.
        self.trajectory = [{"role": "system", "content": f"Welcome to {app_name} simulator! Your task is: {task}"}]
        self.prompt_count = 0
        self.client = OpenAI(api_key=openai_api_key)
        self.actions = []
        # NOTE(review): counts use the GPT-2 tokenizer, not the chat model's own; treat them as an estimate.
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)
        # Back-button stack; starts on the same page as user_state (was hard-coded 'Home').
        self.page_history = [self.user_state['current_page']]

    def _first_page_id(self):
        """Return the id of the first page in the sitemap (the starting page)."""
        return self.sitemap['pages'][0]['id']

    def _current_page(self):
        """Return the page on top of ``page_history``, or the first page if it is empty."""
        return self.page_history[-1] if self.page_history else self._first_page_id()

    def _get_relevant_data(self, current_page):
        """Return the subset of ``database`` relevant to ``current_page``.

        Uses the exact page entry in ``relevant_tables_per_page`` if present,
        otherwise the closest page name (difflib ratio >= 0.5). When nothing
        matches, the whole database is returned. Table names that are not in
        the database are skipped.
        """
        if current_page in self.relevant_tables_per_page:
            relevant_tables = self.relevant_tables_per_page[current_page]
        else:
            closest_match = get_close_matches(current_page, self.relevant_tables_per_page.keys(), n=1, cutoff=0.5)
            if not closest_match:
                return self.database
            relevant_tables = self.relevant_tables_per_page[closest_match[0]]
        return {table: self.database[table] for table in relevant_tables if table in self.database}

    def _get_prerender_page(self, current_page):
        """Return the pre-rendered page for ``current_page``.

        Uses an exact key match when possible, otherwise the closest page name.
        The cutoff is 0, so any non-empty mapping yields a match. Returns an
        empty string when no pre-rendered pages exist; previously this raised
        IndexError.
        """
        if current_page in self.jinjia_prerender_page:
            return self.jinjia_prerender_page[current_page]
        closest_match = get_close_matches(current_page, self.jinjia_prerender_page.keys(), n=1, cutoff=0)
        return self.jinjia_prerender_page[closest_match[0]] if closest_match else ""

    def _generate_system_prompt(self):
        """Build the system prompt for the page currently on top of ``page_history``."""
        current_page = self._current_page()
        last_page = self.page_history[-2] if len(self.page_history) > 1 else self._first_page_id()
        relevant_database = self._get_relevant_data(current_page)
        # Fall back to the full page list when the current page is not in the sitemap.
        relevant_sitemap = next((page for page in self.sitemap["pages"] if page["id"] == current_page),
                                self.sitemap["pages"])
        prerender_page = self._get_prerender_page(current_page)
        return get_system_prompt(app_name=self.app_name,
                                 app_description=self.app_description,
                                 relevant_database=relevant_database,
                                 user_state=self.user_state,
                                 task=self.task,
                                 current_page=current_page,
                                 last_page=last_page,
                                 actions=self.actions,
                                 sitemap_page=relevant_sitemap,
                                 jinjia_prerender=prerender_page,
                                 )

    def _get_openai_response(self, prompt):
        """Send ``prompt`` (a list of chat messages) to OpenAI and return the reply text.

        ``self.conversation`` is trimmed to the token budget first. Calls that fail
        with ``RateLimitError`` are retried up to ``_RETRY_ATTEMPTS`` times, waiting
        ``_RETRY_WAIT_SECONDS`` between attempts; the last error is re-raised.
        A ``None`` message content is returned as an empty string.
        """
        self._trim_conversation()

        # Decorated locally so the retry policy wraps only the network call.
        @retry(retry=retry_if_exception_type(RateLimitError),
               wait=wait_fixed(self._RETRY_WAIT_SECONDS),
               stop=stop_after_attempt(self._RETRY_ATTEMPTS),
               reraise=True)
        def _create():
            return self.client.chat.completions.create(
                model=self.model,
                messages=prompt,
                max_tokens=self.buffer_tokens,
                temperature=0.7,
            )

        response = _create()
        return response.choices[0].message.content or ""

    def _calculate_token_count(self, conversation):
        """Return the total number of tokenizer tokens in the messages' contents."""
        total_tokens = 0
        for entry in conversation:
            tokens = self.tokenizer.encode(entry.get('content') or "", truncation=False, add_special_tokens=False)
            total_tokens += len(tokens)
        return total_tokens

    def _trim_conversation(self):
        """Drop old messages until ``conversation`` fits the token budget.

        The budget is ``max_tokens - 2 * buffer_tokens``. The oldest non-system
        message is dropped first, so the current page's system prompt survives as
        long as possible. The newest message (the one awaiting a reply) is never
        dropped, so the loop always terminates.
        """
        budget = self.max_tokens - self.buffer_tokens * 2
        # Tokenize each message once instead of re-tokenizing the history per pop.
        sizes = [self._calculate_token_count([entry]) for entry in self.conversation]
        total = sum(sizes)
        while total >= budget and len(self.conversation) > 1:
            victim = next(
                (i for i, entry in enumerate(self.conversation[:-1]) if entry.get("role") != "system"),
                0,
            )
            self.conversation.pop(victim)
            total -= sizes.pop(victim)

    def one_conversation_round(self, user_input):
        """Process one user action and return the simulator's reply.

        Returns an "Invalid input" message when ``user_input`` does not match a
        button on the current page. Returns a completion message (after storing
        the new state) when the model reports the task as done. Otherwise
        returns the newly rendered page text.
        """
        self.trajectory.append({"role": "user", "content": f'Human: {user_input}'})
        is_valid, reason = self._is_valid_input(user_input)
        if not is_valid:
            self.prompt_count += 1
            invalid_input_message = f"\n{self.app_name}: Invalid input. {reason}"
            self.trajectory.append({"role": "assistant", "content": invalid_input_message})
            return invalid_input_message

        current_page = self._current_page()
        self.actions.append(f'{user_input} on {current_page} page')
        self.conversation.append({"role": "user", "content": user_input})
        self.prompt_count += 1

        updated_state = self._request_state_update(user_input, current_page)
        updated_state = self._fill_required_state_keys(updated_state, current_page)
        self.user_state = updated_state

        if str(updated_state['task_completed']).lower() == 'true':
            complete_message = f"{self.app_name}: Task completed! You took {self.prompt_count} steps."
            self.trajectory.append({"role": "assistant", "content": complete_message})
            return complete_message

        self._navigate(updated_state)
        return self._render_current_page()

    def _request_state_update(self, user_input, current_page):
        """Ask the model how ``user_input`` changes the user state; return a dict.

        The update prompt is sent as a temporary last message and always removed
        afterwards. The text after the ``UPDATED`` marker (or the whole reply if
        the marker is missing) is parsed with ``json_repair``. Non-dict results are
        sent back for reformatting at most ``_MAX_STATE_REPAIR_ATTEMPTS`` times,
        after which an empty dict is returned.
        """
        update_prompt = get_user_state_update_prompt(user_input=user_input,
                                                     current_page=current_page,
                                                     task=self.task,
                                                     database=self.database,
                                                     solution=self.solution,
                                                     user_state=self.user_state,
                                                     sitemap=self.sitemap)
        update_message = {"role": "user", "content": update_prompt}
        self.conversation.append(update_message)
        try:
            raw_reply = self._get_openai_response(self.conversation)
        finally:
            # The update prompt is one-shot; never leave it in the history, even on error.
            if self.conversation and self.conversation[-1] is update_message:
                self.conversation.pop()

        _, marker, after_marker = raw_reply.partition("UPDATED")
        raw_state = (after_marker if marker else raw_reply).strip()
        updated_state = json_repair.loads(raw_state)
        attempts = 0
        while not isinstance(updated_state, dict) and attempts < self._MAX_STATE_REPAIR_ATTEMPTS:
            attempts += 1
            transform_prompt = (
                f"\nTransform {raw_state} to a properly formatted JSON file.\n"
                "Example Output Format:\n"
                "{\n"
                "'current_page': 'Home',\n"
                "'task_completed': False,\n"
                "'back': False\n"
                "}\n"
            )
            raw_state = self._get_openai_response([{"role": "system", "content": transform_prompt}])
            updated_state = json_repair.loads(raw_state)
        return updated_state if isinstance(updated_state, dict) else {}

    def _fill_required_state_keys(self, updated_state, current_page):
        """Return ``updated_state`` with all required keys and a usable page name.

        Missing ``task_completed``/``back`` default to ``False``. A missing, blank
        or non-string ``current_page`` defaults to ``current_page``, the page the
        action was taken on.
        """
        updated_state.setdefault('task_completed', False)
        updated_state.setdefault('back', False)
        page = updated_state.get('current_page')
        if not isinstance(page, str) or not page.strip():
            updated_state['current_page'] = current_page
        return updated_state

    def _navigate(self, updated_state):
        """Apply the page transition described by ``updated_state``.

        A ``back`` of ``True`` (or the string "true", any case) returns to the
        previous page; the first page is never popped. Any other value pushes
        ``current_page``. ``user_state['current_page']`` is then synced to the
        page actually shown.
        """
        if str(updated_state.get('back')).lower() == 'true':
            if len(self.page_history) > 1:
                self.page_history.pop()
        else:
            self.page_history.append(updated_state['current_page'])
        self.user_state['current_page'] = self._current_page()

    def _render_current_page(self):
        """Have the model render the current page; record and return its reply.

        Stale system prompts are dropped first, so only the prompt for the
        current page stays in the history.
        """
        self.conversation = [entry for entry in self.conversation if entry["role"] != "system"]
        self.conversation.append({"role": "system", "content": self._generate_system_prompt()})
        gpt_instruction = self._get_openai_response(self.conversation)
        self.conversation.append({"role": "assistant", "content": gpt_instruction})
        self.trajectory.append({"role": "assistant", "content": gpt_instruction})
        return gpt_instruction

    def start_conversation(self):
        """Render the starting page and return it prefixed with a greeting."""
        greeting = f'\nWelcome to {self.app_name} simulator! Your task is: {self.task} \n'
        return greeting + self._render_current_page()

    def _extract_buttons(self):
        """Map button number -> lower-cased action type from the last assistant message.

        Returns an empty dict if the conversation is empty, the last message is
        not from the assistant, or it has no "you have the following options:"
        section.
        """
        if not self.conversation:
            return {}
        last_message = self.conversation[-1]
        if last_message.get("role") != "assistant":
            return {}
        message_content = last_message.get("content") or ""
        options_split = self._OPTIONS_HEADER_RE.split(message_content)
        if len(options_split) < 2:
            return {}
        # Only the first options section is parsed.
        buttons = self._BUTTON_RE.findall(options_split[1])
        return {number: action_type.strip().lower() for number, _, action_type in buttons}

    def _is_valid_input(self, user_input):
        """Validate ``user_input`` against the buttons on the current page.

        Returns ``[is_valid, message]``. Any input is accepted when the last
        assistant message lists no buttons.
        """
        valid_buttons = self._extract_buttons()
        if not valid_buttons:
            return [True, "Enter Anything is empty"]
        match = self._ACTION_RE.match(user_input)
        if not match:
            return [False,
                    "Your input doesn't match the format: action_type(button number), OR if text_box, use text_box(button number, query), eg. noop(12). No indent before input and No extra input before or after action_type(button number)!"]
        action_type = match.group("action_type").lower()
        button_name = match.group("button_number").strip().lower()
        query = match.group("query")  # Only text_box actions may carry a query.
        if button_name not in valid_buttons:
            return [False,
                    "Invalid Button number! Recall: Each button is in the format: `number. button name: action_type`. Correct example: link(3), text_box(2, query)"]
        if action_type != valid_buttons[button_name]:
            return [False,
                    "Invalid action type! Recall: Each button is in the format: `number. button name: action_type`"]
        # A blank query (e.g. "text_box(1, )") counts as missing.
        if action_type == "text_box" and (query is None or not query.strip()):
            return [False,
                    "Missing Query for action type 'text_box'! Recall: use the format: `text_box(button number, query)`"]
        if action_type != "text_box" and query is not None:
            return [False,
                    "Non-`text_box` action_type cannot take query!"]
        return [True, 'Pass']