Spaces:
Sleeping
Sleeping
| from typing import List, Tuple, Dict, Any, Optional | |
| import logging | |
| import re | |
| from data_models import IdeaForm, IDEA_STAGES | |
| from config import DEFAULT_SYSTEM_PROMPT, STAGES | |
| from utils import ( | |
| get_llm_response, extract_form_data, save_idea_to_database, | |
| load_idea_from_database, update_idea_in_database, | |
| get_db, clear_database, init_db, create_tables, | |
| perform_web_search, optimize_search_query, | |
| SessionLocal, InnovativeIdea | |
| ) | |
class InnovativeIdeaChatbot:
    """Stateful chatbot that guides a user through the idea-development stages.

    Holds the running chat transcript (``chat_history`` as ``(role, text)``
    tuples), the in-progress ``IdeaForm`` and the id of the database record the
    form is persisted to. Stage metadata is read from two sources:
    ``STAGES`` (dicts with ``"name"``, ``"question"`` and ``"field"`` keys) and
    ``IDEA_STAGES`` (objects with ``name``, ``question`` and ``field``
    attributes).
    """

    def __init__(self):
        # Make sure the schema exists before any session is opened.
        create_tables()
        init_db()
        self.idea_form = IdeaForm()
        self.chat_history: List[Tuple[str, str]] = []
        self.idea_id = None  # id of the persisted idea; None until first save
        self.current_stage: Optional[str] = None
        self.api_key: Optional[str] = None
        self.add_system_message(self.get_initial_greeting())

    # ------------------------------------------------------------------ #
    # Conversation basics
    # ------------------------------------------------------------------ #

    def get_initial_greeting(self) -> str:
        """Return the welcome text shown at the start of a conversation.

        Side effect: sets ``self.greeted = True``.
        """
        greeting = """
Welcome to the Innovative Idea Generator! I'm Myamoto, your AI assistant designed to help you refine and develop your innovative ideas.
Here's how we'll work together:
1. We'll go through 10 stages to explore different aspects of your idea.
2. At each stage, I'll ask you questions and provide feedback to help you think deeper about your concept.
3. You can ask me questions at any time or request more information on a topic.
4. If you want to perform a web search for additional information, just start your message with '@' followed by your search query.
5. When you're ready to move to the next stage, simply type 'next'.
Let's start by exploring your innovative idea! What's the name of your idea, or would you like help coming up with one?
"""
        self.greeted = True
        return greeting

    def add_system_message(self, message: str):
        """Append a ``("System", message)`` entry to the chat history."""
        self.chat_history.append(("System", message))

    def set_api_key(self, api_key: str):
        """Store the API key forwarded to ``get_llm_response``."""
        self.api_key = api_key

    def activate_stage(self, stage_name: str) -> Optional[str]:
        """Make ``stage_name`` the current stage and return its intro message.

        Returns None when no entry in ``STAGES`` has that name. In that case
        ``current_stage`` is left untouched, so a repeated request for the same
        unknown stage is rejected again rather than treated as already active.
        """
        for stage in STAGES:
            if stage["name"] == stage_name:
                self.current_stage = stage_name
                return f"Let's work on the '{stage_name}' stage. {stage['question']}"
        return None

    def generate_prompt_for_stage(self, stage: str) -> str:
        """Return a short "we are working on <stage>" prompt, using the
        stage's question from ``IDEA_STAGES`` when one matches."""
        for s in IDEA_STAGES:
            if s.name == stage:
                return f"We are currently working on the '{stage}' stage. {s.question}"
        return f"We are currently working on the '{stage}' stage. Please provide relevant information."

    # ------------------------------------------------------------------ #
    # Message handling
    # ------------------------------------------------------------------ #

    def process_stage_input(self, stage_name: str, message: str, model: str, system_prompt: str, thinking_budget: int) -> Tuple[List[Tuple[str, str]], Dict[str, Any]]:
        """Handle one user message within ``stage_name``.

        Flow:
          1. Activate the stage if it is not current. An unknown stage appends
             an error message and returns immediately.
          2. Messages starting with ``'@'`` run a web search instead of an LLM
             call.
          3. Otherwise, send the stage-specific system prompt plus the message
             to the LLM. Append the cleaned reply to the history and copy any
             ``<form_data>`` entry for this stage into the idea form.

        ``system_prompt`` is accepted for interface compatibility but is
        ignored; ``DEFAULT_SYSTEM_PROMPT`` from config is always used.

        Returns:
            ``(chat_history, idea_form.dict())``.
        """
        if self.current_stage != stage_name:
            activation_message = self.activate_stage(stage_name)
            if activation_message is None:
                error_message = f"Error: Unable to activate stage '{stage_name}'. Please check if the stage name is correct."
                self.chat_history.append(("System", error_message))
                return self._state()
            self.chat_history.append(("System", activation_message))

        # Web search request: "@<query>".
        if message.startswith('@'):
            search_query = message[1:].strip()
            optimized_query = optimize_search_query(search_query, model)
            search_results = perform_web_search(optimized_query)
            self.chat_history.append(("Human", message))
            self.chat_history.append(("AI", f"Here are the search results for '{optimized_query}':\n\n{search_results}"))
            return self._state()

        combined_prompt = f"{self._format_system_prompt(stage_name)}\n\nUser input: {message}"
        llm_response = get_llm_response(combined_prompt, model, thinking_budget, self.api_key)

        self.chat_history.append(("Human", message))
        # Only the user-facing part of the reply goes into the transcript.
        self.chat_history.append(("AI", self.parse_llm_response(llm_response)))

        self._apply_stage_form_data(stage_name, extract_form_data(llm_response))
        return self._state()

    def process_stage_input_stream(self, stage_name: str, message: str, model: str, system_prompt: str, thinking_budget: int):
        """Generator variant of :meth:`process_stage_input`.

        Yields exactly one ``(chat_history, idea_form.dict())`` tuple.
        """
        yield self.process_stage_input(stage_name, message, model, system_prompt, thinking_budget)

    def parse_llm_response(self, response: str) -> str:
        """Strip internal markup from an LLM reply, keeping user-facing text.

        Removes whole ``<form_data>``, ``<reflection>``, ``<analysis>``,
        ``<summary>`` and ``<step>`` sections (in that order), then drops any
        remaining tag and collapses all whitespace, including newlines, to
        single spaces. A falsy ``response`` (e.g. None) yields ``""``.
        """
        if not response:
            return ""
        for tag in ("form_data", "reflection", "analysis", "summary", "step"):
            response = re.sub(rf'<{tag}>.*?</{tag}>', '', response, flags=re.DOTALL)
        # Remove any remaining HTML-like tags.
        response = re.sub(r'<[^>]+>', '', response)
        return re.sub(r'\s+', ' ', response).strip()

    # ------------------------------------------------------------------ #
    # Form generation
    # ------------------------------------------------------------------ #

    def fill_out_form(self, current_stage: str, model: str, thinking_budget: int) -> Dict[str, str]:
        """Regenerate the entry for ``current_stage`` and persist the form.

        Entries for all other stages are taken from the existing idea form.

        Returns:
            Dict keyed by each stage's form ``field``, in ``STAGES`` order.
        """
        form_data = {
            stage["field"]: (
                self.generate_form_data(stage["name"], model, thinking_budget)
                if stage["name"] == current_stage
                else getattr(self.idea_form, stage["field"], "")
            )
            for stage in STAGES
        }
        for field, value in form_data.items():
            setattr(self.idea_form, field, value)
        self._persist_idea_form("Error saving idea to database")
        return form_data

    def fill_out_form_stream(self, current_stage: str, model: str, thinking_budget: int):
        """Generator variant of form filling driven by ``IDEA_STAGES``.

        Unlike :meth:`fill_out_form`, the yielded dict is keyed by stage
        *name*, not field. The form is updated and persisted before the single
        yield, so a consumer that reads only the first item still gets the data
        saved.
        """
        form_data = {
            stage.name: (
                self.generate_form_data(stage.name, model, thinking_budget)
                if stage.name == current_stage
                else getattr(self.idea_form, stage.field, "")
            )
            for stage in IDEA_STAGES
        }
        for stage in IDEA_STAGES:
            setattr(self.idea_form, stage.field, form_data[stage.name])
        self._persist_idea_form("Error saving idea to database")
        yield form_data

    def generate_form_data(self, stage: str, model: str, thinking_budget: int) -> str:
        """Ask the LLM to summarise the conversation into a form entry for ``stage``.

        Returns the ``stage`` entry from ``extract_form_data`` on the reply, or
        ``""`` if the reply has none.
        """
        prompt = self._build_form_data_prompt(stage)
        llm_response = get_llm_response(prompt, model, thinking_budget, self.api_key)
        return extract_form_data(llm_response).get(stage, "")

    def generate_form_data_stream(self, stage: str, model: str, thinking_budget: int):
        """Same as :meth:`generate_form_data`.

        Despite its name, this is not a generator; it returns a ``str``.
        """
        return self.generate_form_data(stage, model, thinking_budget)

    # ------------------------------------------------------------------ #
    # Form updates / reset
    # ------------------------------------------------------------------ #

    def update_idea_form(self, stage_name: str, form_data: str):
        """Set the form field derived from ``stage_name`` and persist the form."""
        setattr(self.idea_form, self._field_name(stage_name), form_data)
        self._persist_idea_form("Error updating idea form")

    def update_form_field(self, stage_name: str, value: str):
        """Set the form field derived from ``stage_name``, persist, and return the stored value.

        For ``team_roles``, a comma-separated string is converted to a list of
        trimmed, non-empty role names. A value that is already a list is stored
        as-is.
        """
        field_name = self._field_name(stage_name)
        if field_name == 'team_roles' and isinstance(value, str):
            value = [role.strip() for role in value.split(',') if role.strip()]
        setattr(self.idea_form, field_name, value)
        self._persist_idea_form("Error updating form field")
        return value

    def reset(self):
        """Clear the conversation, form and idea id, and wipe the database.

        Returns:
            ``(chat_history, idea_form.dict())``.
        """
        self._reset_conversation()
        self.idea_id = None
        self._run_db_transaction(clear_database, "Error clearing database")
        return self._state()

    def start_over(self):
        """Clear the conversation and form, wipe the database, and create a blank idea.

        On success ``idea_id`` points at the new blank idea. On a database
        error it keeps its previous value.

        Returns:
            ``(chat_history, idea_form.dict(), name of the first stage)``.
        """
        self._reset_conversation()

        def create_blank_idea(session):
            clear_database(session)
            new_idea = InnovativeIdea()
            session.add(new_idea)
            session.commit()
            # Reload the row so the database-assigned id is available.
            session.refresh(new_idea)
            return new_idea.id

        new_id = self._run_db_transaction(create_blank_idea, "Error in start_over")
        if new_id is not None:
            self.idea_id = new_id
        return self.chat_history, self.idea_form.dict(), STAGES[0]["name"]

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    def _state(self) -> Tuple[List[Tuple[str, str]], Dict[str, Any]]:
        """Return the ``(chat_history, idea_form.dict())`` pair handed back to the UI."""
        return self.chat_history, self.idea_form.dict()

    @staticmethod
    def _field_name(stage_name: str) -> str:
        """Map a stage display name to an attribute name, e.g. "Foo Bar" -> "foo_bar"."""
        return stage_name.lower().replace(" ", "_")

    def _format_system_prompt(self, stage: str) -> str:
        """Fill ``DEFAULT_SYSTEM_PROMPT`` with the stage name and its stage prompt."""
        return DEFAULT_SYSTEM_PROMPT.format(
            current_stage=stage,
            stage_prompt=self.generate_prompt_for_stage(stage),
        )

    def _apply_stage_form_data(self, stage_name: str, form_data: Dict[str, Any]) -> None:
        """Copy ``form_data[stage_name]`` into the matching form field, if present."""
        if stage_name in form_data:
            setattr(self.idea_form, self._field_name(stage_name), form_data[stage_name])

    def _build_form_data_prompt(self, stage: str) -> str:
        """Build the prompt asking the LLM to distil the chat into a ``stage`` form entry."""
        conversation = "\n".join(f"{role}: {message}" for role, message in self.chat_history)
        return f"""
{self._format_system_prompt(stage)}
Based on the following conversation, extract the relevant information for the '{stage}' stage of the innovative idea:
{conversation}
Please provide a concise summary for the '{stage}' stage, focusing only on the information relevant to this stage.
Your response should be structured as follows:
1. A brief analysis of the conversation related to this stage.
2. A concise summary of the key points relevant to this stage.
3. A suggested form entry for this stage, enclosed in <form_data></form_data> tags.
The form entry should be in the format: "{stage}: Content"
Remember to keep the form entry concise and directly related to the '{stage}' stage. Do not include information from other stages in the form entry.
"""

    def _reset_conversation(self) -> None:
        """Reset the transcript, form and current stage, then re-post the greeting."""
        self.chat_history = []
        self.idea_form = IdeaForm()
        self.current_stage = None
        self.add_system_message(self.get_initial_greeting())

    @staticmethod
    def _run_db_transaction(work, error_prefix: str):
        """Run ``work(session)`` in a fresh session, commit, and return its result.

        Best effort: any failure, including failure to open the session, is
        logged with ``error_prefix``. The session is rolled back if it exists
        and None is returned. The session is always closed.
        """
        session = None
        try:
            session = SessionLocal()
            result = work(session)
            session.commit()
            return result
        except Exception as exc:
            logging.exception("%s: %s", error_prefix, exc)
            if session is not None:
                session.rollback()
            return None
        finally:
            if session is not None:
                session.close()

    def _persist_idea_form(self, error_prefix: str) -> None:
        """Insert or update the current idea form in the database.

        ``idea_id`` is updated only after the commit succeeds, so a failed save
        never leaves it pointing at a rolled-back record.
        """
        def save(session):
            if self.idea_id:
                update_idea_in_database(self.idea_id, self.idea_form, session)
                return self.idea_id
            return save_idea_to_database(self.idea_form, session)

        saved_id = self._run_db_transaction(save, error_prefix)
        if saved_id is not None:
            self.idea_id = saved_id
def fill_out_form_button(chatbot: InnovativeIdeaChatbot, current_stage: str, model: str, thinking_budget: int):
    """Handle a click on the "Fill out form" button.

    Calls ``chatbot.fill_out_form`` to regenerate the entry for
    ``current_stage`` and returns a ``{field: value}`` dict with one key per
    entry of ``STAGES``, in ``STAGES`` order.
    """
    filled = chatbot.fill_out_form(current_stage, model, thinking_budget)
    result = {}
    for stage_spec in STAGES:
        field = stage_spec["field"]
        result[field] = filled[field]
    return result