File size: 5,315 Bytes
144a2a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import uuid
from typing import List, Optional
from urllib.parse import parse_qs, urlparse

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# Model and DB libraries
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
import chromadb
from youtube_transcript_api import YouTubeTranscriptApi

# --- 1. Constants and Configuration ---
# Hugging Face repo and the quantized GGUF weights file to pull from it.
MODEL_REPO = "bartowski/Phi-3.5-mini-instruct_Uncensored-GGUF"
GGUF_FILE = "Phi-3.5-mini-instruct_Uncensored-Q4_K_M.gguf" # Good balance
CHROMA_PATH = "/app/chroma_db" # Path inside the container for persistent storage
COLLECTION_NAME = "chat_history"

# --- 2. Initialize FastAPI app ---
app = FastAPI(
    title="Enhanced RAG API with Memory",
    description="An API with chat history (ChromaDB) and YouTube analysis.",
    version="1.0",
)

# --- 3. Global Variables (will be loaded on startup) ---
# All three are populated once by load_resources() at startup; they remain
# None until then, so request handlers must not run before startup completes.
llm: Optional[Llama] = None
chroma_client: Optional[chromadb.Client] = None
collection: Optional[chromadb.Collection] = None

# --- 4. Startup Event: Load models and initialize DB ---
@app.on_event("startup")
def load_resources():
    """Download/load the GGUF model and open the persistent Chroma collection.

    Runs once when the FastAPI app starts and populates the module-level
    globals ``llm``, ``chroma_client``, and ``collection``.
    """
    global llm, chroma_client, collection
    
    # Load the LLM
    print("Downloading and loading LLM...")
    # hf_hub_download caches the file locally and returns its filesystem path.
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=GGUF_FILE)
    # n_ctx=4096: context window; n_gpu_layers=-1: offload all layers to GPU.
    llm = Llama(model_path=model_path, n_ctx=4096, n_gpu_layers=-1, verbose=True)
    print("LLM loaded.")

    # Initialize ChromaDB client
    print("Initializing ChromaDB...")
    # This creates a persistent DB client that stores data in the specified path
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    # Get or create the collection to store chat history
    collection = chroma_client.get_or_create_collection(name=COLLECTION_NAME)
    print("ChromaDB initialized.")
    print("API is ready to go! 🚀")

# --- 5. Pydantic Models for API requests ---
class ChatRequest(BaseModel):
    """Request body for POST /chat: one user message within a session."""
    session_id: str = Field(..., description="Unique identifier for a chat session.")
    message: str = Field(..., description="The user's message.")

class YouTubeRequest(BaseModel):
    """Request body for POST /analyze_youtube: the video to summarize."""
    video_url: str = Field(..., description="URL of the YouTube video to analyze.")

# --- 6. API Endpoint for Chat with Memory ---
@app.post("/chat")
def chat_with_memory(request: ChatRequest):
    """Answer a chat message using stored session history as context.

    Retrieves up to 5 semantically relevant prior exchanges for this session
    from ChromaDB, builds a Phi-3.5 chat-template prompt with them, generates
    a reply with the local LLM, and persists the new exchange.

    Returns:
        dict with ``session_id`` and the generated ``response`` text.
    """
    print(f"Received chat request for session: {request.session_id}")
    
    # Step 1: Retrieve relevant chat history from ChromaDB.
    # BUG FIX: Collection.query() requires query_texts (or query_embeddings);
    # the original call passed only `where`/`n_results`, so it always raised
    # and history was silently never used. Querying with the incoming message
    # also makes retrieval semantic instead of arbitrary.
    try:
        history = collection.query(
            query_texts=[request.message],
            where={"session_id": request.session_id},
            n_results=5  # Get the 5 most relevant exchanges
        )
        # query() returns one nested list per query text; guard empty results.
        docs = history["documents"][0] if history["documents"] else []
        metas = history["metadatas"][0] if history["metadatas"] else []
        # Format history for the prompt
        context = "\n".join(
            f"User: {meta['user_message']}\nAI: {doc}"
            for doc, meta in zip(docs, metas)
        )
    except Exception as e:
        print(f"Error querying ChromaDB: {e}")
        context = "" # Start fresh if history fails

    # Step 2: Construct the prompt with history (Phi-3.5 chat template tokens).
    prompt_template = (
        "<s><|system|>\nYou are a helpful AI assistant. "
        "Use the chat history below to provide a relevant and coherent response.\n\n"
        "--- Chat History ---\n{chat_history}\n--- End History ---\n<|end|>\n"
        "<|user|>\n{user_message}<|end|>\n<|assistant|>"
    )
    prompt = prompt_template.format(chat_history=context, user_message=request.message)

    # Step 3: Generate a response from the LLM.
    output = llm(prompt=prompt, max_tokens=256, stop=["<|end|>", "User:"], echo=False)
    ai_response = output["choices"][0]["text"].strip()
    
    # Step 4: Save the new exchange to ChromaDB.
    try:
        # We store the AI response as the document and the user message in metadata
        doc_id = str(uuid.uuid4())  # Chroma requires caller-supplied unique ids
        collection.add(
            ids=[doc_id],
            documents=[ai_response],
            metadatas=[{"session_id": request.session_id, "user_message": request.message}]
        )
        print(f"Saved new exchange to session {request.session_id}")
    except Exception as e:
        print(f"Error saving to ChromaDB: {e}")

    return {"session_id": request.session_id, "response": ai_response}

# --- 7. API Endpoint for YouTube Video Analysis ---
def _extract_video_id(url: str) -> str:
    """Return the YouTube video ID from common URL forms.

    Handles watch?v=..., youtu.be/<id>, and /embed/<id>, /shorts/<id>,
    /v/<id> paths; falls back to the naive "v=" split for other strings.

    Raises:
        HTTPException: 400 when no video ID can be extracted.
    """
    parsed = urlparse(url)
    host = (parsed.hostname or "").lower()
    video_id = ""
    if host.endswith("youtu.be"):
        # Short links carry the ID directly in the path: youtu.be/<id>
        video_id = parsed.path.lstrip("/").split("/")[0]
    elif "youtube.com" in host:
        if parsed.path == "/watch":
            video_id = parse_qs(parsed.query).get("v", [""])[0]
        else:
            # /embed/<id>, /shorts/<id>, /v/<id>: ID is the last path segment
            segments = [s for s in parsed.path.split("/") if s]
            video_id = segments[-1] if segments else ""
    elif "v=" in url:
        # Original fallback for non-URL strings that still contain v=<id>
        video_id = url.split("v=")[1].split("&")[0]
    if not video_id:
        raise HTTPException(status_code=400, detail="Could not extract video ID from URL.")
    return video_id

@app.post("/analyze_youtube")
def analyze_youtube_video(request: YouTubeRequest):
    """Fetch a YouTube video's transcript and return an LLM-written summary.

    Returns:
        dict with the original ``video_url`` and the generated ``summary``.

    Raises:
        HTTPException: 400 for an unparseable URL; 500 for transcript or
        generation failures.
    """
    try:
        # BUG FIX: the original `url.split("v=")[1]` raised IndexError on any
        # URL without "v=" (youtu.be, /shorts/, /embed/) and surfaced as a 500.
        video_id = _extract_video_id(request.video_url)
        print(f"Fetching transcript for video ID: {video_id}")
        
        # Get transcript
        transcript_list = YouTubeTranscriptApi.get_transcript(video_id)
        transcript = " ".join([item['text'] for item in transcript_list])
        print("Transcript fetched successfully.")

        # Create a prompt for summarization
        prompt = (
            f"<s><|system|>\nYou are an expert analyst. Summarize the key points of the following YouTube video transcript."
            f"<|end|>\n<|user|>\nTranscript: {transcript[:3000]}\n\nSummary:<|end|>\n<|assistant|>" # Truncate to fit context
        )
        
        # Get summary from LLM
        output = llm(prompt, max_tokens=512, stop=["<|end|>"], echo=False)
        summary = output["choices"][0]["text"].strip()
        
        return {"video_url": request.video_url, "summary": summary}

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) rather than
        # collapsing them into a generic 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
def read_root():
    """Liveness probe: report that the service process is up."""
    status_payload = {"status": "API is running."}
    return status_payload