import os
import uuid
from typing import List, Optional

import chromadb
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from pydantic import BaseModel, Field
from youtube_transcript_api import YouTubeTranscriptApi

# --- 1. Constants and Configuration ---
MODEL_REPO = "bartowski/Phi-3.5-mini-instruct_Uncensored-GGUF"
GGUF_FILE = "Phi-3.5-mini-instruct_Uncensored-Q4_K_M.gguf"  # Good balance
CHROMA_PATH = "/app/chroma_db"  # Path inside the container for persistent storage
COLLECTION_NAME = "chat_history"

# --- 2. Initialize FastAPI app ---
app = FastAPI(
    title="Enhanced RAG API with Memory",
    description="An API with chat history (ChromaDB) and YouTube analysis.",
    version="1.0",
)

# --- 3. Global resources, populated once in load_resources() at startup ---
# FIX: the originals were annotated `llm: Llama = None` etc. — a non-Optional
# annotation with a None value.  Use Optional so the declared type is honest.
llm: Optional[Llama] = None
chroma_client = None  # persistent ChromaDB client, set in load_resources()
collection = None  # ChromaDB collection holding the chat history


# --- 4. Startup event: load the model and open the database ---
# NOTE(review): @app.on_event("startup") is deprecated in current FastAPI in
# favor of lifespan handlers; kept as-is so behavior and interface are
# unchanged — consider migrating when the FastAPI version allows.
@app.on_event("startup")
def load_resources() -> None:
    """Download/load the GGUF model and open the persistent Chroma collection.

    Runs once when the server starts; fills the module-level globals that the
    request handlers read.
    """
    global llm, chroma_client, collection

    # Load the LLM (hf_hub_download caches the file, so restarts are cheap).
    print("Downloading and loading LLM...")
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=GGUF_FILE)
    llm = Llama(model_path=model_path, n_ctx=4096, n_gpu_layers=-1, verbose=True)
    print("LLM loaded.")

    # Initialize ChromaDB: a persistent client stores its data under CHROMA_PATH.
    print("Initializing ChromaDB...")
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    # Get or create the collection used to store the chat history.
    collection = chroma_client.get_or_create_collection(name=COLLECTION_NAME)
    print("ChromaDB initialized.")
    print("API is ready to go! 🚀")
# --- 5. Pydantic models for API requests ---
class ChatRequest(BaseModel):
    session_id: str = Field(..., description="Unique identifier for a chat session.")
    message: str = Field(..., description="The user's message.")


class YouTubeRequest(BaseModel):
    video_url: str = Field(..., description="URL of the YouTube video to analyze.")


# --- 6. API endpoint for chat with memory ---
@app.post("/chat")
def chat_with_memory(request: ChatRequest):
    """Answer a chat message, using stored exchanges from the same session.

    Retrieves prior exchanges for ``request.session_id`` from ChromaDB, inlines
    them into a Phi-3.5 chat prompt, generates a reply, and persists the new
    (user message, AI reply) pair back into the collection.
    """
    print(f"Received chat request for session: {request.session_id}")

    # Step 1: retrieve chat history for this session.
    # BUG FIX: the original called collection.query(where=..., n_results=5)
    # without query_texts/query_embeddings, which ChromaDB rejects — the
    # exception was swallowed below, so history was NEVER retrieved and every
    # turn started fresh.  collection.get() does a pure metadata filter, which
    # is what is wanted here, and returns flat (not nested) result lists.
    # NOTE(review): get() does not guarantee "most recent 5" ordering — confirm
    # whether recency matters and sort by a stored timestamp if so.
    try:
        history = collection.get(
            where={"session_id": request.session_id},
            limit=5,
        )
        context = "\n".join(
            f"User: {meta['user_message']}\nAI: {doc}"
            for doc, meta in zip(history["documents"], history["metadatas"])
        )
    except Exception as e:
        print(f"Error querying ChromaDB: {e}")
        context = ""  # Start fresh if history fails

    # Step 2: construct the prompt with history.
    prompt_template = (
        "<|system|>\nYou are a helpful AI assistant. "
        "Use the chat history below to provide a relevant and coherent response.\n\n"
        "--- Chat History ---\n{chat_history}\n--- End History ---\n<|end|>\n"
        "<|user|>\n{user_message}<|end|>\n<|assistant|>"
    )
    prompt = prompt_template.format(chat_history=context, user_message=request.message)

    # Step 3: generate a response from the LLM.
    output = llm(prompt=prompt, max_tokens=256, stop=["<|end|>", "User:"], echo=False)
    ai_response = output["choices"][0]["text"].strip()

    # Step 4: save the new exchange to ChromaDB.
    # The AI response is the document; the user message rides in metadata.
    try:
        doc_id = str(uuid.uuid4())
        collection.add(
            ids=[doc_id],
            documents=[ai_response],
            metadatas=[{"session_id": request.session_id, "user_message": request.message}],
        )
        print(f"Saved new exchange to session {request.session_id}")
    except Exception as e:
        # Best-effort persistence: a failed save should not lose the reply.
        print(f"Error saving to ChromaDB: {e}")

    return {"session_id": request.session_id, "response": ai_response}


def _extract_video_id(video_url: str) -> str:
    """Return the YouTube video id from common URL shapes.

    Handles ``watch?v=<id>``, ``youtu.be/<id>``, and ``/embed|shorts|v/<id>``
    paths.  Raises ValueError when no id can be found.
    """
    # Local import keeps this fix self-contained within the block.
    from urllib.parse import parse_qs, urlparse

    parsed = urlparse(video_url)
    host = (parsed.hostname or "").lower()
    if host.endswith("youtu.be"):
        vid = parsed.path.lstrip("/").split("/")[0]
        if vid:
            return vid
    qs = parse_qs(parsed.query)
    if qs.get("v") and qs["v"][0]:
        return qs["v"][0]
    parts = [p for p in parsed.path.split("/") if p]
    if len(parts) >= 2 and parts[0] in ("embed", "shorts", "v"):
        return parts[1]
    raise ValueError(f"Could not extract a video id from URL: {video_url!r}")


# --- 7. API endpoint for YouTube video analysis ---
@app.post("/analyze_youtube")
def analyze_youtube_video(request: YouTubeRequest):
    """Fetch a YouTube transcript and return an LLM-generated summary.

    Returns 400 for URLs no video id can be extracted from, 500 for transcript
    or generation failures.
    """
    # BUG FIX: the original `url.split("v=")[1]` raised IndexError (surfaced
    # as an opaque 500) for youtu.be / embed / shorts links.  Parse the URL
    # properly and report malformed addresses as a client error.
    try:
        video_id = _extract_video_id(request.video_url)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    try:
        print(f"Fetching transcript for video ID: {video_id}")
        # Get transcript segments and join them into one string.
        transcript_list = YouTubeTranscriptApi.get_transcript(video_id)
        transcript = " ".join(item["text"] for item in transcript_list)
        print("Transcript fetched successfully.")

        # Build the summarization prompt; truncate the transcript so the
        # prompt fits the model's 4096-token context window.
        prompt = (
            "<|system|>\nYou are an expert analyst. Summarize the key points of the following YouTube video transcript."
            f"<|end|>\n<|user|>\nTranscript: {transcript[:3000]}\n\nSummary:<|end|>\n<|assistant|>"
        )

        # Get the summary from the LLM.
        output = llm(prompt, max_tokens=512, stop=["<|end|>"], echo=False)
        summary = output["choices"][0]["text"].strip()
        return {"video_url": request.video_url, "summary": summary}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
def read_root():
    """Liveness check."""
    return {"status": "API is running."}