|
|
import ast
import asyncio
import functools
import os
import sqlite3
import uuid
from datetime import datetime
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from groq import Groq
from pydantic import BaseModel
|
|
|
|
|
# FastAPI application object; the route decorators below register on it.
app = FastAPI()

# Groq client used by the /chat handler. The API key is taken from the
# GROQ_API_KEY environment variable (None is passed when it is unset).
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Schema for the conversation store: one row per conversation id holding the
# serialized message history and the time of the most recent write.
_CREATE_CONVERSATIONS_TABLE = '''CREATE TABLE conversations
             (id TEXT PRIMARY KEY, messages TEXT, last_updated TIMESTAMP)'''

# ':memory:' keeps the database inside this process only, so every stored
# conversation is discarded when the process exits.
conn = sqlite3.connect(':memory:')
# Module-level cursor that the request handlers reuse for all queries.
c = conn.cursor()
c.execute(_CREATE_CONVERSATIONS_TABLE)
conn.commit()
|
|
|
|
|
class Message(BaseModel):
    """A single chat message, as sent by clients and forwarded to Groq."""

    # Speaker of the message. It is not validated here and is passed through
    # to the Groq API unchanged (presumably "system"/"user"/"assistant" -- TODO confirm).
    role: str
    # Text of the message.
    content: str
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    # Conversation to continue. When omitted (or empty) the /chat handler
    # starts a new conversation under a freshly generated UUID.
    conversation_id: Optional[str] = None
    # New messages, in order, to append to the conversation's stored history.
    messages: List[Message]
|
|
|
|
|
class ChatResponse(BaseModel):
    """Response body for POST /chat."""

    # Id the reply was stored under. A client can pass it back as
    # ChatRequest.conversation_id to continue the same conversation.
    conversation_id: str
    # The assistant's reply text returned by the model.
    response: str
|
|
|
|
|
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Append the client's messages to a conversation and return the model's reply.

    If ``request.conversation_id`` is missing or empty, a new UUID is used.
    Any stored history for that id is prepended to the new messages, the
    combined list is sent to Groq's ``llama3-8b-8192`` model, and the
    assistant's reply is appended and persisted before being returned.

    Returns:
        A dict matching ``ChatResponse``: the conversation id and the
        assistant's reply text.
    """
    conversation_id = request.conversation_id or str(uuid.uuid4())

    c.execute("SELECT messages FROM conversations WHERE id = ?", (conversation_id,))
    row = c.fetchone()
    # History is stored as the repr() of a list of dicts. ast.literal_eval
    # parses that literal without executing code, unlike eval(), so an
    # unexpected or tampered row cannot run arbitrary Python.
    history = ast.literal_eval(row[0]) if row else []
    messages = history + [{"role": msg.role, "content": msg.content} for msg in request.messages]

    # The Groq client call is synchronous. Run it in the default thread pool
    # so that waiting on the completion does not block the event loop (and
    # with it every other request). SQLite access stays on the loop thread.
    # NOTE(review): concurrent requests for the same conversation_id can now
    # interleave; the later write wins.
    loop = asyncio.get_running_loop()
    chat_completion = await loop.run_in_executor(
        None,
        functools.partial(
            client.chat.completions.create,
            messages=messages,
            model="llama3-8b-8192",
        ),
    )
    response = chat_completion.choices[0].message.content

    messages.append({"role": "assistant", "content": response})
    c.execute(
        "INSERT OR REPLACE INTO conversations (id, messages, last_updated) VALUES (?, ?, ?)",
        # isoformat(" ") is exactly the text sqlite3's default datetime
        # adapter produced; that implicit adapter is deprecated since 3.12.
        (conversation_id, str(messages), datetime.now().isoformat(" ")),
    )
    conn.commit()

    return {"conversation_id": conversation_id, "response": response}
|
|
|
|
|
@app.get("/conversations/{conversation_id}")
async def get_conversation(conversation_id: str):
    """Return the stored message history for ``conversation_id``.

    Returns:
        ``{"conversation_id": ..., "messages": [...]}``, where ``messages`` is
        the full stored list of role/content dicts, including assistant replies.

    Raises:
        HTTPException: 404 if no conversation with that id has been stored.
    """
    c.execute("SELECT messages FROM conversations WHERE id = ?", (conversation_id,))
    row = c.fetchone()
    if row is None:
        raise HTTPException(status_code=404, detail="Conversation not found")
    # History is stored as the repr() of a list of dicts. ast.literal_eval
    # parses that literal safely; eval() would execute arbitrary code found
    # in the database.
    return {"conversation_id": conversation_id, "messages": ast.literal_eval(row[0])}
|
|
|
|
|
@app.get("/")
async def read_root():
    """Return a static welcome message for the API root."""
    greeting = "Welcome to the Groq chatbot API with in-memory conversation persistence!"
    return dict(message=greeting)
|
|
|
|
|
if __name__ == "__main__":
    # Executing the file directly serves the app with uvicorn on all
    # interfaces, port 7860. The import lives inside this guard, so merely
    # importing the module does not import uvicorn.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)