# code-interpreter / main.py
# Last commit df5259e by ashishbangwal: "conversation ID update"
import ast
import asyncio
import logging
import os
import sqlite3
import time
from typing import Dict, List

from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from modules.functions import call_llm
# NOTE(review): debug=True is expected to expose tracebacks in error
# responses -- confirm this is not meant for production.
app = FastAPI(debug=True)

# Logging: every record goes both to ``app.log`` and to the console.
# NOTE(review): with the level at WARNING, the logger.info(...) calls in this
# module are filtered out (unless logging is reconfigured elsewhere).
logging.basicConfig(
    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.WARNING,
)
logger = logging.getLogger(__name__)

# SQLite file holding the durable copy of every conversation. The path is
# relative, so it resolves against the process's working directory.
DB_PATH = "app/data/conversations.db"

# Process-local caches; they are lost on restart.
#   CONVERSATIONS: conversation id -> list of {"role": ..., "content": ...}
#   LAST_ACTIVITY: conversation id -> time.time() of its last activity
CONVERSATIONS: Dict[str, List[Dict[str, str]]] = dict()
LAST_ACTIVITY: Dict[str, float] = dict()
# initialize SQLite database
def init_db():
    """Create the SQLite file (and its parent directory) and the schema.

    Idempotent: uses ``CREATE TABLE IF NOT EXISTS``, so it is safe to call on
    every start-up. Reads the module-level ``DB_PATH``.

    Note: ``IF NOT EXISTS`` never alters an existing table. Databases created
    before the comma after ``messages TEXT`` was added have no ``timestamp``
    column (SQLite had parsed ``TEXT timestamp DATETIME`` as one type name).
    Such files keep that old schema until they are migrated by hand.
    """
    logger.info("Initializing database")
    db_dir = os.path.dirname(DB_PATH)
    if db_dir:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(db_dir, exist_ok=True)
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            """CREATE TABLE IF NOT EXISTS conversations
            (id INTEGER PRIMARY KEY AUTOINCREMENT,
            conversation_id TEXT,
            messages TEXT,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)"""
        )
        conn.commit()
    finally:
        # Close even if the DDL fails so the handle is not leaked.
        conn.close()
    logger.info("Database initialized successfully")


init_db()
def update_db(conversation_id, messages):
    """Persist the full message history of ``conversation_id`` to SQLite.

    The history is stored as ``str(messages)``, a Python literal rather than
    JSON. Existing row(s) for the id are overwritten; if none match, a new
    row is inserted.

    Args:
        conversation_id: Key stored in the ``conversation_id`` column.
        messages: The complete history to store (replaces any previous value).
    """
    logger.info(f"Updating database for conversation: {conversation_id}")
    serialized = str(messages)
    conn = sqlite3.connect(DB_PATH)
    try:
        # Try the UPDATE first; rowcount == 0 means no row exists yet. This
        # saves the separate SELECT COUNT(*) round trip.
        cursor = conn.execute(
            "UPDATE conversations SET messages = ? WHERE conversation_id = ?",
            (serialized, conversation_id),
        )
        if cursor.rowcount == 0:
            conn.execute(
                "INSERT INTO conversations (conversation_id, messages) VALUES (?, ?)",
                (conversation_id, serialized),
            )
        conn.commit()
    finally:
        # Always release the connection, even if a statement raised.
        conn.close()
    logger.info("Database updated successfully")
def get_conversation_from_db(conversation_id):
    """Return the stored ``messages`` text for ``conversation_id``, or ``None``.

    The value is returned exactly as stored in the ``messages`` column (no
    parsing). If several rows share the id, whichever row SQLite yields first
    is used.
    """
    connection = sqlite3.connect(DB_PATH)
    try:
        row = connection.execute(
            "SELECT messages FROM conversations WHERE conversation_id = ?",
            (conversation_id,),
        ).fetchone()
    finally:
        connection.close()
    return row[0] if row else None
async def clear_inactive_conversations(
    inactivity_timeout: float = 1800, check_interval: float = 60
):
    """Periodically evict idle conversations from the in-memory caches.

    Runs forever. Each pass drops every conversation whose ``LAST_ACTIVITY``
    timestamp is more than ``inactivity_timeout`` seconds older than
    ``time.time()``, then sleeps ``check_interval`` seconds. Only the
    in-memory ``CONVERSATIONS`` / ``LAST_ACTIVITY`` entries are removed; the
    SQLite rows are untouched, so ``get_response`` can reload them later.

    Args:
        inactivity_timeout: Idle time in seconds before eviction (default 30 min).
        check_interval: Seconds to sleep between passes (default 1 min).

    NOTE(review): no code visible here schedules this coroutine -- confirm it
    is started as a background task (e.g. at application start-up).
    """
    while True:
        logger.info("Clearing inactive conversations")
        current_time = time.time()
        # Materialize the ids first so the dicts are not mutated mid-iteration.
        inactive_convos = [
            conv_id
            for conv_id, last_time in LAST_ACTIVITY.items()
            if current_time - last_time > inactivity_timeout
        ]
        for conv_id in inactive_convos:
            CONVERSATIONS.pop(conv_id, None)
            LAST_ACTIVITY.pop(conv_id, None)
        logger.info(f"Cleared {len(inactive_convos)} inactive conversations")
        await asyncio.sleep(check_interval)  # check once per interval
# One rendered piece of an answer. ``type`` names the kind of payload (the
# Response examples use "text" and "plotly"); ``content`` holds it as a string.
Output = TypedDict("Output", {"type": str, "content": str})
class UserInput(BaseModel):
    """Request body for ``POST /response``.

    Both fields are required strings. The PascalCase names are the JSON keys
    clients send, so renaming them would break callers.
    """

    # Conversation key: used for the in-memory cache and the SQLite
    # ``conversation_id`` column (the example is UUID-shaped).
    ConversationID: str = Field(
        ..., examples=["123e4567-e89b-12d3-a456-426614174000"]
    )
    # The user's request for this turn.
    Query: str = Field(..., examples=["Nifty 50 Annual return for past 10 years"])
class Response(BaseModel):
    """Response body for ``POST /response``.

    ``response`` lists the rendered output items (each an ``Output``), and
    ``executed_code`` lists the code strings returned alongside them. The
    examples are Python source.
    """

    response: List[Output] = Field(
        ...,
        examples=[
            [
                {
                    "type": "text",
                    "content": "### Nifty 50 Annual Return for Past 10 Years...",
                },
                {
                    "type": "plotly",
                    "content": '{"data":[{"x":[null,6.75517596225125.....}',
                },
            ]
        ],
    )
    executed_code: List[str] = Field(
        ...,
        examples=[
            [
                "import folium\nm = folium.Map(location=[35, 100....",
                'from IPython.display import Image\nurls = ["https://up',
            ]
        ],
    )
@app.post("/response")
async def get_response(user_query: UserInput) -> Response:
    """Answer ``user_query.Query`` in the context of its conversation.

    History is looked up in the in-memory cache first, then in SQLite. If
    neither has it, a new conversation is started. The user turn is
    appended, the whole history is sent to ``call_llm``, and the assistant
    reply is appended. The updated history is then cached, its activity time
    recorded, and the history persisted.

    Returns:
        ``{"response": results, "executed_code": python_code}`` as produced
        by ``call_llm``, validated by FastAPI against ``Response``.
    """
    conv_id = user_query.ConversationID
    query = user_query.Query

    if conv_id in CONVERSATIONS:
        previous = CONVERSATIONS[conv_id]
    else:
        stored = get_conversation_from_db(conv_id)
        # update_db stores str(list_of_dicts), i.e. a Python literal. Parse it
        # with ast.literal_eval instead of eval so a corrupt or tampered row
        # cannot execute arbitrary code.
        previous = ast.literal_eval(stored) if stored else []
    # Concatenation builds a new list; the cached list is never mutated.
    history = previous + [{"role": "user", "content": query}]
    logger.debug("History for conversation %s: %s", conv_id, history)

    # NOTE(review): call_llm is synchronous, so it blocks the event loop (and
    # every other request) until it returns. Consider offloading it to a
    # thread only if call_llm is confirmed thread-safe.
    results, llm_response, python_code = call_llm(history)
    history += [{"role": "assistant", "content": llm_response}]

    CONVERSATIONS[conv_id] = history
    # Record activity so clear_inactive_conversations can evict idle entries;
    # without this, LAST_ACTIVITY stays empty and the cache grows forever.
    LAST_ACTIVITY[conv_id] = time.time()
    update_db(conversation_id=conv_id, messages=history)
    return {"response": results, "executed_code": python_code}  # type:ignore