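# SysModeler chatbot: a Gradio app that answers SysML modeling questions using
# Azure OpenAI function calling, backed by a local FAISS index (faiss_index_sysml).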
import os
import gradio as gr
import warnings
import json
from dotenv import load_dotenv
from typing import Dict, Any, List, Optional
import time
from functools import lru_cache
import logging
from langchain.agents import Tool, AgentExecutor
from langchain.tools.retriever import create_retriever_tool
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import AzureOpenAIEmbeddings
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from openai import AzureOpenAI
# Patch Gradio bug
import gradio_client.utils
gradio_client.utils.json_schema_to_python_type = lambda schema, defs=None: "string"
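# NOTE (assumption): some gradio_client versions crash in json_schema_to_python_type
# on certain API schemas (e.g. boolean "additionalProperties"). Stubbing it to always
# return "string" avoids the launch error, at the cost of accurate type info in the API docs.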
# Load environment variables
load_dotenv()
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_LLM_DEPLOYMENT = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
AZURE_OPENAI_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
if not all([AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_LLM_DEPLOYMENT, AZURE_OPENAI_EMBEDDING_DEPLOYMENT]):
    raise ValueError("Missing one or more Azure OpenAI environment variables.")
warnings.filterwarnings("ignore")
# Embeddings for retriever
embeddings = AzureOpenAIEmbeddings(
    azure_deployment=AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    openai_api_key=AZURE_OPENAI_API_KEY,
    openai_api_version="2025-01-01-preview",
    chunk_size=1000
)
# Get the directory where this script is located
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Build the absolute path to the faiss_index_sysml directory relative to this script
FAISS_INDEX_PATH = os.path.join(SCRIPT_DIR, "faiss_index_sysml")
# Load FAISS vectorstore
vectorstore = FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
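# NOTE: allow_dangerous_deserialization=True is required because FAISS.load_local
# unpickles the stored index metadata; only enable it for index files you built yourself.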
# Initialize Azure OpenAI client directly
client = AzureOpenAI(
    api_key=AZURE_OPENAI_API_KEY,
    api_version="2025-01-01-preview",
    azure_endpoint=AZURE_OPENAI_ENDPOINT
)
logger = logging.getLogger(__name__)
# SysML retriever function
def sysml_retriever(query: str) -> str:
    start_time = time.time()
    try:
        results = vectorstore.similarity_search(query, k=100)
        contexts = [doc.page_content for doc in results]
        response = "\n\n".join(contexts)
        # Log performance metrics
        duration = time.time() - start_time
        print(f"Retrieval completed in {duration:.2f}s for query: {query[:50]}...")
        return response
    except Exception as e:
        logger.error(f"Retrieval error: {str(e)}")
        return "Unable to retrieve information at this time."
# sysml_retriever = create_retriever_tool(
#     retriever=vectorstore.as_retriever(),
#     name="SysMLRetriever",
#     description="Use this to answer questions about SysML diagrams and modeling."
# )
# Dummy functions
def dummy_weather_lookup(location: str = "London") -> str:
    return f"The weather in {location} is sunny and 25°C."

def dummy_time_lookup(timezone: str = "UTC") -> str:
    return f"The current time in {timezone} is 3:00 PM."
# Tools definition for OpenAI function calling
tools_definition = [
    {
        "type": "function",
        "function": {
            "name": "SysMLRetriever",
            "description": "Use this to answer questions about SysML diagrams and modeling.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query to find information about SysML"
                    }
                },
                "required": ["query"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "WeatherLookup",
            "description": "Use this to look up the current weather in a specified location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The location to look up the weather for"
                    }
                },
                "required": ["location"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "TimeLookup",
            "description": "Use this to look up the current time in a specified timezone.",
            "parameters": {
                "type": "object",
                "properties": {
                    "timezone": {
                        "type": "string",
                        "description": "The timezone to look up the current time for"
                    }
                },
                "required": ["timezone"]
            }
        }
    }
]
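# Each entry follows the OpenAI/Azure OpenAI function-calling schema: "parameters" is a
# JSON Schema object describing the arguments. The model only returns the chosen tool
# name and a JSON argument string; the application is responsible for executing it.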
# Tool execution mapping
tool_mapping = {
    "SysMLRetriever": sysml_retriever,
    "WeatherLookup": dummy_weather_lookup,
    "TimeLookup": dummy_time_lookup
}
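# The keys here must match the "name" fields in tools_definition exactly, because
# tool_call.function.name from the model is used to look up the Python callable.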
# Convert chat history
def convert_history_to_messages(history):
    messages = []
    for user, bot in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": bot})
    return messages
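# Example: [("What is a block?", "A block is ...")] becomes
# [{"role": "user", "content": "What is a block?"},
#  {"role": "assistant", "content": "A block is ..."}],
# matching the (user, bot) tuple format used by gr.Chatbot below.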
# Main chatbot function with direct function calling
def sysml_chatbot(message, history):
    # Convert history to messages format
    chat_messages = convert_history_to_messages(history)
    # Add system message at beginning
    full_messages = [
        {"role": "system", "content": "You are a helpful SysML modeling assistant and a capable general-purpose assistant."}
    ]
    full_messages.extend(chat_messages)
    # Add current user message
    full_messages.append({"role": "user", "content": message})
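    # Two-step tool-calling flow: the first completion may return tool_calls; if it does,
    # the matching Python function is executed locally, its result is appended as a "tool"
    # message, and a second completion produces the final user-facing answer.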
    try:
        # First call to get either a direct answer or a function call
        response = client.chat.completions.create(
            model=AZURE_OPENAI_LLM_DEPLOYMENT,
            messages=full_messages,
            tools=tools_definition,
            tool_choice="auto"  # let the model decide between answering directly and calling a tool
        )
        assistant_message = response.choices[0].message
        # Check if the model wants to call a function
        if assistant_message.tool_calls:
            # Get the function call details
            tool_call = assistant_message.tool_calls[0]
            function_name = tool_call.function.name
            function_args = json.loads(tool_call.function.arguments)
            print("Attempting function calling...")
            # Execute the function
            if function_name in tool_mapping:
                function_response = tool_mapping[function_name](**function_args)
                # Append the assistant's request and the function response to messages
                full_messages.append({"role": "assistant", "content": None, "tool_calls": [
                    {"id": tool_call.id, "type": "function", "function": {"name": function_name, "arguments": tool_call.function.arguments}}
                ]})
                full_messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": function_response
                })
                # Second call to get the final answer based on the function result
                second_response = client.chat.completions.create(
                    model=AZURE_OPENAI_LLM_DEPLOYMENT,
                    messages=full_messages
                )
                answer = second_response.choices[0].message.content
                print("Getting final response after function execution...")
                print(f"Function '{function_name}' executed successfully. Response: {answer}")
            else:
                answer = f"I tried to use a function '{function_name}' that's not available. Let me try again with general knowledge: SysML is a modeling language for systems engineering that helps visualize and analyze complex systems."
        else:
            # Model provided a direct answer
            answer = assistant_message.content
        history.append((message, answer))
        # Clear the textbox, refresh the chat display, and persist the updated history
        return "", history, history
    except Exception as e:
        print(f"Error in function calling: {str(e)}")
        # Fallback to a direct response without function calling
        try:
            simple_messages = [
                {"role": "system", "content": "You are a helpful SysML modeling assistant."}
            ]
            simple_messages.extend(chat_messages)
            simple_messages.append({"role": "user", "content": message})
            fallback_response = client.chat.completions.create(
                model=AZURE_OPENAI_LLM_DEPLOYMENT,
                messages=simple_messages
            )
            answer = fallback_response.choices[0].message.content
        except Exception as fallback_error:
            print(f"Error in fallback: {str(fallback_error)}")
            answer = "I'm having trouble accessing my tools right now. SysML is a modeling language used in systems engineering to visualize and analyze complex systems through various diagram types."
        history.append((message, answer))
        return "", history, history
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## SysModeler Chatbot")
    chatbot = gr.Chatbot(height=600)
    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                placeholder="Ask me about SysML diagrams or concepts...",
                lines=3,
                show_label=False
            )
        with gr.Column(scale=1, min_width=50):
            submit_btn = gr.Button("➤")
            clear = gr.Button("Clear")
    # Conversation history lives in a State component; it starts empty for each session
    state = gr.State([])
    submit_btn.click(fn=sysml_chatbot, inputs=[msg, state], outputs=[msg, chatbot, state])
    msg.submit(fn=sysml_chatbot, inputs=[msg, state], outputs=[msg, chatbot, state])  # still supports enter key
    clear.click(fn=lambda: ([], [], ""), inputs=None, outputs=[chatbot, state, msg])
if __name__ == "__main__":
    demo.launch()