Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from sqlalchemy.exc import SQLAlchemyError | |
| from utils import InnovativeIdea, init_db, get_db, SessionLocal, get_llm_response | |
| from data_models import IdeaForm | |
| from chatbot import InnovativeIdeaChatbot | |
| from config import MODELS, DEFAULT_SYSTEM_PROMPT, STAGES | |
| import logging | |
| import os | |
| from typing import Dict, Any | |
| import re | |
| import time | |
# Configure the root logger once at import time: INFO and above, with each
# record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def create_gradio_interface():
    """Build the Gradio Blocks app for the innovative-idea assistant.

    The function initialises the database, reads (or creates) the first
    ``InnovativeIdea`` row to seed the form, then assembles the chat column,
    the per-stage form column and all event wiring.

    Returns:
        The assembled ``gr.Blocks`` instance. It is not launched here.

    Raises:
        RuntimeError: If database initialisation or the initial query raises
            a ``SQLAlchemyError`` (chained as the cause).
    """
    innovative_chatbot = InnovativeIdeaChatbot()
    default_stage = STAGES[0]["name"]

    # Read the stored values into a plain dict while the session is open, so
    # the UI code below never touches an ORM object after its session closed.
    try:
        init_db()
        db = next(get_db())
        try:
            initial_idea = db.query(InnovativeIdea).first()
            if initial_idea is None:
                logging.info("No initial idea found in the database. Creating a new one.")
                initial_idea = InnovativeIdea()
                db.add(initial_idea)
                db.commit()
                db.refresh(initial_idea)
            initial_values = {
                stage["name"]: getattr(initial_idea, stage["field"], "")
                for stage in STAGES
            }
        finally:
            # Close on every path, including a failing query or commit.
            db.close()
    except SQLAlchemyError as e:
        logging.error(f"Database initialization failed: {str(e)}")
        raise RuntimeError(f"Failed to initialize database: {str(e)}") from e

    def persist_fields(values):
        """Copy ``values`` onto the first stored idea row and commit.

        A row is created when none exists. A list under ``'team_roles'`` is
        stored as a comma-joined string.
        """
        session = SessionLocal()
        try:
            idea = session.query(InnovativeIdea).first()
            if idea is None:
                idea = InnovativeIdea()
                session.add(idea)
            for key, value in values.items():
                if key == 'team_roles' and isinstance(value, list):
                    value = ','.join(value)  # Convert list to string for database storage
                setattr(idea, key, value)
            session.commit()
        finally:
            session.close()

    def unchanged_fields():
        """Return one no-op update per stage textbox."""
        return [gr.update() for _ in STAGES]

    def stage_field_updates(current_stage, form_data):
        """Return per-textbox updates that touch only the selected stage.

        The value is looked up by the stage's ``field`` key (the key the other
        handlers use), falling back to the stage name.
        NOTE(review): confirm which key the chatbot's form data really uses.
        """
        updates = []
        for stage in STAGES:
            if stage["name"] != current_stage:
                updates.append(gr.update())
                continue
            value = form_data.get(stage["field"], form_data.get(current_stage, ""))
            updates.append(gr.update(value=value))
        return updates

    def chatbot_function(message, history, model, system_prompt, thinking_budget, current_stage):
        """Stream the chatbot reply for ``message`` into the conversation.

        Yields tuples of ``(history, *one update per stage textbox, "")``; the
        trailing empty string clears the message input.
        """
        history = history or []
        try:
            # If this is the first message, get the initial greeting
            if not history:
                initial_greeting = innovative_chatbot.get_initial_greeting()
                history.append((None, initial_greeting))
                yield (history, *unchanged_fields(), "")
                return

            turn_index = None
            stream = innovative_chatbot.process_stage_input_stream(
                current_stage, message, model, system_prompt, thinking_budget
            )
            for chat_history, form_data in stream:
                turn = (message, chat_history[-1][1])
                if turn_index is None:
                    history.append(turn)
                    turn_index = len(history) - 1
                else:
                    # Later partial responses refresh the same row instead of
                    # adding a duplicate row per streamed chunk.
                    history[turn_index] = turn
                yield (history, *stage_field_updates(current_stage, form_data), "")
                # Update the database with the new form data
                persist_fields(form_data)
        except Exception as e:
            logging.error(f"An error occurred in chatbot_function: {str(e)}", exc_info=True)
            yield (history + [(None, f"An error occurred: {str(e)}")], *unchanged_fields(), "")

    def fill_form(stage, model, thinking_budget):
        """Ask the chatbot to fill the form; return one value per stage."""
        form_data = innovative_chatbot.fill_out_form(stage, model, thinking_budget)
        return [form_data.get(entry["field"], "") for entry in STAGES]

    def clear_chat():
        """Reset the stored idea to an empty form and reset the chatbot."""
        # Reset the database to an empty form
        persist_fields(IdeaForm().dict())
        chat_history, form_data = innovative_chatbot.reset()
        return (chat_history, *[form_data.get(entry["field"], "") for entry in STAGES])

    def start_over():
        """Restart the chatbot; return values for chat, input, fields, stage."""
        chat_history, form_data, initial_stage = innovative_chatbot.start_over()
        return (
            chat_history,  # Update the chatbot with the new chat history
            "",  # Clear the message input
            *[form_data.get(entry["field"], "") for entry in STAGES],  # Reset all form fields
            gr.update(value=initial_stage),  # Reset the stage selection
        )

    def toggle_mode(new_mode):
        """Switch visibility/interactivity between chat and direct-input mode.

        Returns updates for: chatbot, message box, submit button, every stage
        textbox, and the "Submit Form" button, in that order.
        """
        direct = new_mode == "Direct Input"
        return (
            [gr.update(visible=not direct)] * 3
            + [gr.update(interactive=direct)] * len(STAGES)
            + [gr.update(visible=direct)]
        )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Innovative Idea Generator")
        mode = gr.Radio(["Chatbot", "Direct Input"], label="Mode", value="Chatbot")
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Conversation", height=500)
                msg = gr.Textbox(label="Your input", placeholder="Type your brilliant idea here...")
                with gr.Row():
                    submit = gr.Button("Submit")
                    clear = gr.Button("Clear Chat")
                    start_over_btn = gr.Button("Start Over")
            with gr.Column(scale=1):
                stages = gr.Radio(
                    choices=[stage["name"] for stage in STAGES],
                    label="Ideation Stages",
                    value=default_stage,
                )
                # One textbox per stage; only the selected stage's is visible.
                form_fields = {
                    stage["name"]: gr.Textbox(
                        label=stage["question"],
                        placeholder=stage["example"],
                        value=initial_values[stage["name"]],
                        visible=(stage["name"] == default_stage),
                        interactive=False,
                    )
                    for stage in STAGES
                }
                fill_form_btn = gr.Button("Fill out Form")
                submit_form_btn = gr.Button("Submit Form", visible=False)
        with gr.Accordion("Advanced Settings", open=False):
            model = gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0])
            system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=5)
            thinking_budget = gr.Slider(minimum=1, maximum=4098, value=2048, step=1, label="Max New Tokens")
            api_key = gr.Textbox(label="Hugging Face API Key", type="password")

        # Textboxes in STAGES order, matching every handler's return order.
        stage_boxes = [form_fields[stage["name"]] for stage in STAGES]

        # Event handlers
        chat_inputs = [msg, chatbot, model, system_prompt, thinking_budget, stages]
        chat_outputs = [chatbot, *stage_boxes, msg]
        msg.submit(chatbot_function, inputs=chat_inputs, outputs=chat_outputs)
        submit.click(chatbot_function, inputs=chat_inputs, outputs=chat_outputs)
        fill_form_btn.click(fill_form,
                            inputs=[stages, model, thinking_budget],
                            outputs=stage_boxes)
        clear.click(clear_chat,
                    outputs=[chatbot] + stage_boxes)

        # Update form field visibility based on selected stage
        stages.change(
            lambda s: [gr.update(visible=(stage["name"] == s)) for stage in STAGES],
            inputs=[stages],
            outputs=stage_boxes,
        )

        # Update API key when changed
        api_key.change(innovative_chatbot.set_api_key, inputs=[api_key])

        # Toggle between chatbot and direct input mode
        mode.change(
            toggle_mode,
            inputs=[mode],
            outputs=[chatbot, msg, submit] + stage_boxes + [submit_form_btn],
        )

        # Handle direct form submission
        submit_form_btn.click(
            lambda *values: values,
            inputs=stage_boxes,
            outputs=stage_boxes,
        )

        # Start Over resets the chat, the input box, every field and the stage
        start_over_btn.click(
            start_over,
            outputs=[chatbot, msg] + stage_boxes + [stages],
        )

        # Display the initial greeting when the interface loads
        demo.load(lambda: ([[None, innovative_chatbot.get_initial_greeting()]], ""),
                  outputs=[chatbot, msg])

        # Forward every textbox change to the chatbot's form state. The stage
        # name is bound as a default argument so each lambda keeps its own.
        for stage in STAGES:
            form_fields[stage["name"]].change(
                lambda value, s=stage["name"]: innovative_chatbot.update_form_field(s, value),
                inputs=[form_fields[stage["name"]]],
                outputs=[form_fields[stage["name"]]],
            )
    return demo
def main():
    """Build the Gradio app, reporting any failure instead of raising.

    Returns:
        Whatever ``create_gradio_interface()`` returns, or ``None`` when
        building it raised an exception. In that case the error is logged
        with its traceback and a short message is printed.
    """
    try:
        return create_gradio_interface()
    except ImportError as import_error:
        # Handled separately only so an import-specific hint can be printed.
        logging.exception(f"Import error: {import_error!s}")
        print(f"An import error occurred: {import_error!s}")
        print("Please check your import statements and ensure there are no circular dependencies.")
    except Exception as error:
        logging.exception(f"Failed to initialize application: {error!s}")
        print(f"An unexpected error occurred: {error!s}")
        print("Please check the log file for more details.")
    # Reached only after one of the handlers above ran.
    return None
if __name__ == "__main__":
    # Script entry point: main() returns None when the app could not be
    # built, so launch only when it handed back an app object.
    try:
        demo = main()
        if demo:
            demo.launch()
    except Exception as launch_error:
        logging.exception(f"Failed to start the application: {launch_error!s}")
        print(f"An error occurred while starting the application: {launch_error!s}")
        print("Please check the log file for more details.")
        # You might want to add a more user-friendly error message or UI here