# Source: uploaded by 0Learn ("Upload 9 files", commit cf6b61d, verified).
# chat.py
# Filepath: ai-email-assistant/chat.py
# Implements the chat interface and message handling
import inspect
import json
import logging
import os

import gradio as gr
from dotenv import load_dotenv
from groq import Groq

from web_search import search_web
from google_sheets_utils import (
    get_or_create_spreadsheet,
    get_profiles_from_sheet,
    delete_profile_from_sheet,
    get_profile_summaries,
    get_template_summaries,
    get_profile_by_id,
    get_template_by_id,
    save_profile,
    update_profile_in_sheet,
    delete_template_from_sheet,
    save_template,
    update_template,
    generate_template_id,
    generate_profile_id,
)
from email_actions import get_templates
# Process-wide logging setup plus this module's named logger.
# (basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
def load_system_prompt(filename):
    """Return the whitespace-stripped contents of the prompt file *filename*.

    The file is decoded as UTF-8 explicitly; without ``encoding`` Python uses the
    locale's preferred encoding, so non-ASCII prompt text could load differently
    (or fail) depending on the host.

    Args:
        filename: Path to the prompt file. Relative paths resolve against the
            process's current working directory.

    Returns:
        The file text with leading/trailing whitespace removed, or the
        placeholder string ``"System prompt file not found."`` when the file does
        not exist. The placeholder is kept for backward compatibility; a warning
        is logged because ``create_chat_interface`` passes this return value
        straight to the model as its system prompt.
    """
    try:
        with open(filename, "r", encoding="utf-8") as file:
            return file.read().strip()
    except FileNotFoundError:
        # Use the module logger by name so this works wherever the function runs.
        logging.getLogger(__name__).warning(
            "System prompt file %r not found; using placeholder prompt", filename
        )
        return "System prompt file not found."
# Upper bound on model <-> tool round trips per user message, so a model that
# keeps requesting tools cannot loop indefinitely.
_MAX_TOOL_ROUNDS = 5

# Maps Python annotation names to JSON-schema types for tool parameter schemas.
# Unannotated or unrecognized parameters are advertised as strings.
_JSON_SCHEMA_TYPES = {
    "str": "string",
    "int": "integer",
    "float": "number",
    "bool": "boolean",
    "list": "array",
    "dict": "object",
}


def _build_tool_schema(func_name, func):
    """Return a chat-completions function-tool schema describing *func*.

    Parameter names, JSON types and the ``required`` list are derived from
    ``func``'s signature, so the model is told which arguments each tool takes.
    """
    properties = {}
    required = []
    try:
        params = list(inspect.signature(func).parameters.values())
    except (TypeError, ValueError):
        # Some callables expose no introspectable signature; advertise no params.
        params = []
    for param in params:
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        annotation = param.annotation
        # String annotations appear under `from __future__ import annotations`.
        type_name = annotation if isinstance(annotation, str) else getattr(annotation, "__name__", "")
        properties[param.name] = {"type": _JSON_SCHEMA_TYPES.get(type_name, "string")}
        if param.default is param.empty:
            required.append(param.name)
    return {
        "type": "function",
        "function": {
            "name": func_name,
            "description": inspect.getdoc(func) or "",
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


def _run_tool_call(tool_call, available_functions):
    """Execute one model-requested tool call and return its result as JSON text.

    An unknown tool name, malformed arguments, or an exception raised by the tool
    is reported back to the model as ``{"error": ...}`` instead of propagating
    and aborting the whole chat turn.
    """
    function_name = tool_call.function.name
    function_to_call = available_functions.get(function_name)
    if function_to_call is None:
        return json.dumps({"error": f"Function {function_name} not found."})
    try:
        function_args = json.loads(tool_call.function.arguments or "{}")
    except json.JSONDecodeError as err:
        return json.dumps({"error": f"Invalid JSON arguments for {function_name}: {err}"})
    if not isinstance(function_args, dict):
        return json.dumps({"error": f"Arguments for {function_name} must be a JSON object."})
    try:
        function_response = function_to_call(**function_args)
    except Exception as err:  # tool failures (bad args, Sheets/web errors) go back to the model
        logger.exception("Tool %s failed with arguments %s", function_name, function_args)
        return json.dumps({"error": f"{function_name} failed: {err}"})
    # The result is only text for the model, so fall back to str() for any value
    # json cannot encode natively.
    return json.dumps(function_response, default=str)


def _assistant_tool_call_message(message):
    """Echo an assistant message that requested tools back into the transcript.

    Each ``role: "tool"`` message must follow the assistant message whose
    ``tool_calls`` it answers, so this must be appended before the tool results.
    """
    return {
        "role": "assistant",
        "content": message.content,
        "tool_calls": [
            {
                "id": call.id,
                "type": "function",
                "function": {"name": call.function.name, "arguments": call.function.arguments},
            }
            for call in message.tool_calls
        ],
    }


def create_chat_interface(sender_summaries, receiver_summaries, template_summaries,
                          current_sender, current_receiver, current_template,
                          sender_components, receiver_components, email_actions_components):
    """Build the assistant chat UI and wire it to a Groq tool-calling model.

    The model can read and modify sender/receiver profiles and email templates in
    the Google spreadsheet, and search the web, through the tools defined below.

    Args:
        sender_summaries, receiver_summaries, template_summaries: Iterables of
            strings, joined with ", " into the context message on every turn.
        current_sender, current_receiver, current_template: Objects whose
            ``.value`` is included in the context when truthy.
        sender_components, receiver_components, email_actions_components:
            Accepted for interface compatibility; not used here.

    Returns:
        ``(chatbot, msg, clear)``: the Chatbot, input Textbox and Clear Button.

    Raises:
        ValueError: If ``GROQ_API_KEY`` is not set (after loading ``.env``).
    """
    system_prompt = load_system_prompt('system_prompt.txt')
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    # Kept module-global as before. NOTE(review): confirm whether any other code
    # reads ``spreadsheet`` before turning this into a plain local.
    global spreadsheet
    spreadsheet = get_or_create_spreadsheet()

    # Initialize Groq client
    load_dotenv()  # Load environment variables from .env file
    groq_api_key = os.getenv('GROQ_API_KEY')
    if not groq_api_key:
        logger.error("GROQ_API_KEY not found in environment variables")
        raise ValueError("GROQ_API_KEY not found in environment variables")
    client = Groq(api_key=groq_api_key)
    model = "llama3-groq-70b-8192-tool-use-preview"

    # Tools exposed to the assistant. Their signatures and docstrings become the
    # tool schemas the model sees (see _build_tool_schema).
    def get_profile_summaries_tool(profile_type: str):
        """Get summaries of profiles (Sender or Receiver)"""
        return get_profile_summaries(spreadsheet.worksheet(profile_type), profile_type)

    def get_template_summaries_tool():
        """Get summaries of email templates"""
        return get_template_summaries(spreadsheet)

    def get_profile_by_id_tool(profile_type: str, profile_id: str):
        """Get a specific profile by ID"""
        return get_profile_by_id(spreadsheet.worksheet(profile_type), profile_id)

    def get_template_by_id_tool(template_id: str):
        """Get a specific template by ID"""
        return get_template_by_id(spreadsheet.worksheet("Email Templates"), template_id)

    def create_profile_tool(profile_type: str, name: str, email: str, position: str, company: str, context: str):
        """Create a new profile"""
        return save_profile(spreadsheet.worksheet(profile_type),
                            [profile_type, generate_profile_id(), name, email, position, company, context])

    def update_profile_tool(profile_type: str, profile_id: str, name: str, email: str, position: str, company: str, context: str):
        """Update an existing profile"""
        return update_profile_in_sheet(spreadsheet.worksheet(profile_type), profile_id,
                                       [profile_type, profile_id, name, email, position, company, context])

    def delete_profile_tool(profile_type: str, profile_id: str):
        """Delete a profile"""
        return delete_profile_from_sheet(spreadsheet.worksheet(profile_type), profile_id)

    def create_template_tool(name: str, template_type: str, subject: str, body: str):
        """Create a new email template"""
        return save_template(spreadsheet.worksheet("Email Templates"), generate_template_id(),
                             name, template_type, subject, body)

    def update_template_tool(template_id: str, name: str, template_type: str, subject: str, body: str):
        """Update an existing email template"""
        return update_template(spreadsheet.worksheet("Email Templates"), template_id,
                               name, template_type, subject, body)

    def delete_template_tool(template_id: str):
        """Delete an email template"""
        return delete_template_from_sheet(spreadsheet.worksheet("Email Templates"), template_id)

    # Map tool names (as the model sees them) to the callables that implement them.
    available_functions = {
        "get_profile_summaries": get_profile_summaries_tool,
        "get_template_summaries": get_template_summaries_tool,
        "get_profile_by_id": get_profile_by_id_tool,
        "get_template_by_id": get_template_by_id_tool,
        "create_profile": create_profile_tool,
        "update_profile": update_profile_tool,
        "delete_profile": delete_profile_tool,
        "create_template": create_template_tool,
        "update_template": update_template_tool,
        "delete_template": delete_template_tool,
        "search_web": search_web,
    }
    # Schemas carry real parameter lists so the model knows each tool's arguments.
    tools = [_build_tool_schema(name, func) for name, func in available_functions.items()]

    def get_context():
        """Return a system message describing the current selection and options."""
        # NOTE(review): if these are Gradio components, ``.value`` is the value
        # they were created with, not what the user currently has selected; and
        # the summary lists were captured when the UI was built, so profiles or
        # templates created through tools may not appear here — confirm intent.
        context = "Current context:\n"
        if current_sender.value:
            context += f"Current Sender: {current_sender.value}\n"
        if current_receiver.value:
            context += f"Current Receiver: {current_receiver.value}\n"
        if current_template.value:
            context += f"Current Template: {current_template.value}\n"
        context += f"Available Sender Profiles: {', '.join(sender_summaries)}\n"
        context += f"Available Receiver Profiles: {', '.join(receiver_summaries)}\n"
        context += f"Available Templates: {', '.join(template_summaries)}\n"
        return context

    def complete(messages, tool_choice="auto"):
        """Request one chat completion and return its first message."""
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools,
            tool_choice=tool_choice,
            max_tokens=512,  # Adjust as needed
        )
        return response.choices[0].message

    def respond(user_input, history):
        """Answer *user_input*, running any tools the model asks for.

        Returns ``("", history)`` so the textbox is cleared and the chatbot shows
        the conversation with the new ``(user_input, reply)`` turn appended.
        """
        # The Clear button sets the chatbot value to None; treat it as empty.
        if history is None:
            history = []
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "system", "content": get_context()},
        ]
        for human_msg, ai_msg in history:
            # Skip empty/None turns rather than sending null message content.
            if human_msg:
                messages.append({"role": "user", "content": human_msg})
            if ai_msg:
                messages.append({"role": "assistant", "content": ai_msg})
        messages.append({"role": "user", "content": user_input})

        response_message = complete(messages)
        for _ in range(_MAX_TOOL_ROUNDS):
            tool_calls = getattr(response_message, "tool_calls", None)
            if not tool_calls:
                break
            # The assistant's tool-call message must precede the tool results.
            messages.append(_assistant_tool_call_message(response_message))
            for tool_call in tool_calls:
                messages.append({
                    "role": "tool",
                    "content": _run_tool_call(tool_call, available_functions),
                    "tool_call_id": tool_call.id,
                })
            response_message = complete(messages)
        if getattr(response_message, "tool_calls", None):
            # Round limit reached with tools still requested: force a text answer.
            response_message = complete(messages, tool_choice="none")

        history.append((user_input, response_message.content or ""))
        return "", history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)
    return chatbot, msg, clear