# Source: uploaded by Coderrs ("Upload 6 files", commit 144a2a0, verified)
import os
import re
import time
import uuid
from typing import List, Optional
from urllib.parse import parse_qs, urlparse

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# Model and DB libraries
import chromadb
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from youtube_transcript_api import YouTubeTranscriptApi
# --- 1. Constants and Configuration ---
# Every setting can be overridden with an environment variable of the same name;
# the defaults are the values the service was originally configured with.
MODEL_REPO = os.environ.get(
    "MODEL_REPO", "bartowski/Phi-3.5-mini-instruct_Uncensored-GGUF"
)
# Quantisation file inside MODEL_REPO (Q4_K_M chosen as a good size/quality balance).
GGUF_FILE = os.environ.get(
    "GGUF_FILE", "Phi-3.5-mini-instruct_Uncensored-Q4_K_M.gguf"
)
# Directory used by ChromaDB's persistent client (path inside the container).
CHROMA_PATH = os.environ.get("CHROMA_PATH", "/app/chroma_db")
# Name of the Chroma collection that stores chat exchanges.
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "chat_history")
# --- 2. Initialize FastAPI app ---
# API metadata handed to the FastAPI constructor.
_API_INFO = dict(
    title="Enhanced RAG API with Memory",
    description="An API with chat history (ChromaDB) and YouTube analysis.",
    version="1.0",
)
app = FastAPI(**_API_INFO)
# --- 3. Global Variables (will be loaded on startup) ---
# All three start as None and are assigned by the startup hook (load_resources);
# the request handlers use them directly, assuming startup has completed.
llm: Optional[Llama] = None
# `chromadb.Client` is a factory function, not a class, so the client is annotated
# with Chroma's client interface type (a forward-ref string, never evaluated at import).
chroma_client: Optional["chromadb.ClientAPI"] = None
collection: Optional[chromadb.Collection] = None
# --- 4. Startup Event: Load models and initialize DB ---
@app.on_event("startup")
def load_resources():
    """Load the GGUF model and open the persistent ChromaDB collection.

    Assigns the module-level globals ``llm``, ``chroma_client`` and ``collection``
    that the request handlers use.
    """
    global llm, chroma_client, collection

    # Load the LLM: fetch the GGUF file from the Hub, then build the llama.cpp model.
    print("Downloading and loading LLM...")
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=GGUF_FILE)
    llm = Llama(model_path=model_path, n_ctx=4096, n_gpu_layers=-1, verbose=True)
    print("LLM loaded.")

    # Initialize ChromaDB: a persistent client storing its data under CHROMA_PATH.
    print("Initializing ChromaDB...")
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    # Get or create the collection that stores chat history.
    collection = chroma_client.get_or_create_collection(name=COLLECTION_NAME)
    print("ChromaDB initialized.")

    # The rocket emoji is written as an escape so the source stays ASCII and cannot
    # be re-mangled by a wrong decoding (it was previously stored as mojibake).
    print("API is ready to go! \U0001F680")
# --- 5. Pydantic Models for API requests ---
class ChatRequest(BaseModel):
    # Body of POST /chat: the conversation the message belongs to, plus the message.
    # Both fields are required (default=... marks a Field as required).
    session_id: str = Field(
        default=...,
        description="Unique identifier for a chat session.",
    )
    message: str = Field(
        default=...,
        description="The user's message.",
    )
class YouTubeRequest(BaseModel):
    # Body of POST /analyze_youtube: a single required video URL.
    video_url: str = Field(
        default=...,
        description="URL of the YouTube video to analyze.",
    )
# --- 6. API Endpoint for Chat with Memory ---
def _load_recent_history(session_id: str, limit: int = 5) -> str:
    """Return the session's last ``limit`` exchanges (oldest first) formatted for the prompt.

    Uses ``collection.get`` with a metadata filter. ``collection.query`` is a similarity
    search that requires ``query_texts`` or ``query_embeddings``, so it cannot fetch a
    session's exchanges by metadata alone. ``get`` is not documented to return records in
    insertion order, so exchanges are sorted by the ``timestamp`` metadata written by
    ``chat_with_memory``. Records saved without that field sort first (timestamp 0).
    """
    if limit <= 0:
        return ""
    records = collection.get(
        where={"session_id": session_id},
        include=["documents", "metadatas"],
    )
    exchanges = sorted(
        zip(records.get("documents") or [], records.get("metadatas") or []),
        key=lambda pair: (pair[1] or {}).get("timestamp", 0),
    )
    return "\n".join(
        f"User: {(meta or {}).get('user_message', '')}\nAI: {doc}"
        for doc, meta in exchanges[-limit:]
    )


@app.post("/chat")
def chat_with_memory(request: ChatRequest):
    """Answer ``request.message`` using the session's most recent exchanges as context.

    ChromaDB access is best-effort. A failed history lookup is logged and the request
    continues with empty context. A failed save is logged and the response is still
    returned. Errors from the LLM call are not caught here.

    Returns:
        ``{"session_id": ..., "response": ...}`` with the stripped model output.
    """
    print(f"Received chat request for session: {request.session_id}")

    # Step 1: Retrieve the last 5 exchanges of this session from ChromaDB
    try:
        context = _load_recent_history(request.session_id, limit=5)
    except Exception as e:
        print(f"Error querying ChromaDB: {e}")
        context = ""  # Start fresh if history fails

    # Step 2: Construct the prompt with history (Phi-3.5 chat markup)
    prompt_template = (
        "<s><|system|>\nYou are a helpful AI assistant. "
        "Use the chat history below to provide a relevant and coherent response.\n\n"
        "--- Chat History ---\n{chat_history}\n--- End History ---\n<|end|>\n"
        "<|user|>\n{user_message}<|end|>\n<|assistant|>"
    )
    prompt = prompt_template.format(chat_history=context, user_message=request.message)

    # Step 3: Generate a response from the LLM. "User:" is a stop sequence so the
    # model does not continue the history format with a fabricated user turn.
    output = llm(prompt=prompt, max_tokens=256, stop=["<|end|>", "User:"], echo=False)
    ai_response = output["choices"][0]["text"].strip()

    # Step 4: Save the new exchange to ChromaDB
    try:
        # The AI response is the document; the user message, session and save time go
        # in metadata. The timestamp lets _load_recent_history order exchanges.
        doc_id = str(uuid.uuid4())
        collection.add(
            ids=[doc_id],
            documents=[ai_response],
            metadatas=[{
                "session_id": request.session_id,
                "user_message": request.message,
                "timestamp": time.time(),
            }],
        )
        print(f"Saved new exchange to session {request.session_id}")
    except Exception as e:
        print(f"Error saving to ChromaDB: {e}")

    return {"session_id": request.session_id, "response": ai_response}
# --- 7. API Endpoint for YouTube Video Analysis ---
# Path prefixes whose next segment is the video ID (e.g. /shorts/<id>, /embed/<id>).
_VIDEO_ID_PATH_PREFIXES = ("shorts", "embed", "live", "v")
_VIDEO_ID_RE = re.compile(r"[A-Za-z0-9_-]+")


def _extract_video_id(video_url: str) -> str:
    """Return the YouTube video ID in ``video_url``, or ``""`` if none is found.

    Recognises ``?v=<id>`` query parameters on any host, ``youtu.be/<id>``, and
    ``/shorts|/embed|/live|/v/<id>`` paths. A missing scheme is tolerated.
    """
    url = video_url.strip()
    parsed = urlparse(url if "://" in url else f"https://{url}")
    host = (parsed.hostname or "").lower()
    segments = [seg for seg in parsed.path.split("/") if seg]

    if host == "youtu.be" or host.endswith(".youtu.be"):
        candidate = segments[0] if segments else ""
    else:
        query_ids = parse_qs(parsed.query).get("v")
        if query_ids:
            candidate = query_ids[0]
        elif len(segments) >= 2 and segments[0] in _VIDEO_ID_PATH_PREFIXES:
            candidate = segments[1]
        else:
            candidate = ""
    return candidate if _VIDEO_ID_RE.fullmatch(candidate) else ""


def _fetch_transcript_text(video_id: str) -> str:
    """Fetch the transcript for ``video_id`` and join its text segments with spaces.

    Older youtube-transcript-api releases expose the classmethod ``get_transcript``.
    Newer ones use ``YouTubeTranscriptApi().fetch(video_id)``, whose result converts to
    the same list-of-dicts shape via ``to_raw_data()``. Both are supported.
    """
    legacy_get = getattr(YouTubeTranscriptApi, "get_transcript", None)
    if legacy_get is not None:
        items = legacy_get(video_id)
    else:
        items = YouTubeTranscriptApi().fetch(video_id).to_raw_data()
    return " ".join(item["text"] for item in items)


@app.post("/analyze_youtube")
def analyze_youtube_video(request: YouTubeRequest):
    """Summarize a YouTube video from its transcript.

    Raises:
        HTTPException(400): no video ID could be extracted from ``video_url``.
        HTTPException(422): the fetched transcript contains no text.
        HTTPException(500): fetching the transcript or generating the summary failed.
    """
    # Validate the URL before the try block so a client error is not reported as a 500.
    video_id = _extract_video_id(request.video_url)
    if not video_id:
        raise HTTPException(
            status_code=400,
            detail=f"Could not extract a YouTube video ID from URL: {request.video_url}",
        )

    try:
        print(f"Fetching transcript for video ID: {video_id}")
        transcript = _fetch_transcript_text(video_id)
        print("Transcript fetched successfully.")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not transcript.strip():
        raise HTTPException(
            status_code=422, detail=f"Transcript for video {video_id} is empty."
        )

    # Create a prompt for summarization; the transcript is truncated to fit n_ctx.
    prompt = (
        f"<s><|system|>\nYou are an expert analyst. Summarize the key points of the following YouTube video transcript."
        f"<|end|>\n<|user|>\nTranscript: {transcript[:3000]}\n\nSummary:<|end|>\n<|assistant|>"
    )

    try:
        output = llm(prompt, max_tokens=512, stop=["<|end|>"], echo=False)
        summary = output["choices"][0]["text"].strip()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"video_url": request.video_url, "summary": summary}
@app.get("/")
def read_root():
    """Root endpoint: return a fixed status message showing the server is responding."""
    status_payload = {"status": "API is running."}
    return status_payload