# Enhanced RAG API: FastAPI service with chat memory (ChromaDB) and YouTube
# transcript summarization, backed by a local GGUF Phi-3.5 model.
# Standard library
import os
import uuid
from typing import List
from urllib.parse import parse_qs, urlparse

# Web framework
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# Model and DB libraries
import chromadb
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from youtube_transcript_api import YouTubeTranscriptApi
# --- 1. Constants and Configuration ---
# Persistent storage location (path inside the container) and the name of the
# ChromaDB collection that holds chat exchanges.
CHROMA_PATH = "/app/chroma_db"
COLLECTION_NAME = "chat_history"
# Quantized GGUF build of Phi-3.5-mini; Q4_K_M balances file size and quality.
MODEL_REPO = "bartowski/Phi-3.5-mini-instruct_Uncensored-GGUF"
GGUF_FILE = "Phi-3.5-mini-instruct_Uncensored-Q4_K_M.gguf"
# --- 2. Initialize FastAPI app ---
app = FastAPI(
    title="Enhanced RAG API with Memory",
    description="An API with chat history (ChromaDB) and YouTube analysis.",
    version="1.0",
)

# --- 3. Global Variables (populated by the startup event) ---
# NOTE(review): annotated with their concrete types but initialized to None;
# load_resources() assigns them at startup, before any request is served.
llm: Llama = None
chroma_client: chromadb.Client = None
collection: chromadb.Collection = None
# --- 4. Startup Event: Load models and initialize DB ---
# --- 4. Startup Event: Load models and initialize DB ---
@app.on_event("startup")
def load_resources():
    """Download/load the LLM and open the persistent ChromaDB collection.

    Runs once at application startup and populates the module-level
    ``llm``, ``chroma_client`` and ``collection`` globals that the request
    handlers rely on.
    """
    global llm, chroma_client, collection

    # Load the LLM (downloads the GGUF file on first run, cached afterwards).
    print("Downloading and loading LLM...")
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=GGUF_FILE)
    # n_gpu_layers=-1 offloads all layers to the GPU when one is available.
    llm = Llama(model_path=model_path, n_ctx=4096, n_gpu_layers=-1, verbose=True)
    print("LLM loaded.")

    # Initialize ChromaDB: a persistent client storing data under CHROMA_PATH.
    print("Initializing ChromaDB...")
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    # Get or create the collection used to store chat history.
    collection = chroma_client.get_or_create_collection(name=COLLECTION_NAME)
    print("ChromaDB initialized.")
    print("API is ready to go! 🚀")
# --- 5. Pydantic Models for API requests ---
# --- 5. Pydantic Models for API requests ---
class ChatRequest(BaseModel):
    """Request body for the /chat endpoint."""

    session_id: str = Field(..., description="Unique identifier for a chat session.")
    message: str = Field(..., description="The user's message.")
class YouTubeRequest(BaseModel):
    """Request body for the /analyze_youtube endpoint."""

    video_url: str = Field(..., description="URL of the YouTube video to analyze.")
# --- 6. API Endpoint for Chat with Memory ---
# --- 6. API Endpoint for Chat with Memory ---
@app.post("/chat")
def chat_with_memory(request: ChatRequest):
    """Generate an LLM reply using this session's prior exchanges as context.

    Retrieves up to 5 relevant past exchanges for the session from ChromaDB,
    folds them into a Phi-3 chat prompt, generates a response, and persists
    the new exchange before returning it.
    """
    print(f"Received chat request for session: {request.session_id}")

    # Step 1: Retrieve relevant chat history from ChromaDB.
    # BUG FIX: collection.query() requires query_texts (or query_embeddings);
    # the previous call passed only `where`, which ChromaDB rejects, so the
    # handler always fell back to an empty context. Querying with the incoming
    # message returns the semantically closest past exchanges of this session.
    try:
        history = collection.query(
            query_texts=[request.message],
            where={"session_id": request.session_id},
            n_results=5,
        )
        # Each stored document is an AI reply; its metadata holds the user turn.
        context = "\n".join(
            f"User: {meta['user_message']}\nAI: {doc}"
            for doc, meta in zip(history["documents"][0], history["metadatas"][0])
        )
    except Exception as e:
        print(f"Error querying ChromaDB: {e}")
        context = ""  # Start fresh if history fails

    # Step 2: Construct the prompt with history (Phi-3 chat template).
    prompt_template = (
        "<s><|system|>\nYou are a helpful AI assistant. "
        "Use the chat history below to provide a relevant and coherent response.\n\n"
        "--- Chat History ---\n{chat_history}\n--- End History ---\n<|end|>\n"
        "<|user|>\n{user_message}<|end|>\n<|assistant|>"
    )
    prompt = prompt_template.format(chat_history=context, user_message=request.message)

    # Step 3: Generate a response from the LLM.
    output = llm(prompt=prompt, max_tokens=256, stop=["<|end|>", "User:"], echo=False)
    ai_response = output["choices"][0]["text"].strip()

    # Step 4: Save the new exchange to ChromaDB. Best-effort: a storage
    # failure must not prevent returning the generated answer.
    try:
        # The AI response is the document; the user message lives in metadata.
        doc_id = str(uuid.uuid4())
        collection.add(
            ids=[doc_id],
            documents=[ai_response],
            metadatas=[{"session_id": request.session_id, "user_message": request.message}],
        )
        print(f"Saved new exchange to session {request.session_id}")
    except Exception as e:
        print(f"Error saving to ChromaDB: {e}")

    return {"session_id": request.session_id, "response": ai_response}
# --- 7. API Endpoint for YouTube Video Analysis ---
# --- 7. API Endpoint for YouTube Video Analysis ---
@app.post("/analyze_youtube")
def analyze_youtube_video(request: YouTubeRequest):
    """Fetch a YouTube video's transcript and return an LLM-generated summary.

    Raises:
        HTTPException: 500 with the underlying error message if the URL is
            invalid, the transcript is unavailable, or generation fails.
    """
    try:
        # BUG FIX: the previous `video_url.split("v=")[1]` crashed with
        # IndexError on youtu.be/<id> and /embed/<id> URLs; use a real parser.
        video_id = _extract_video_id(request.video_url)
        print(f"Fetching transcript for video ID: {video_id}")

        # Get transcript and flatten it to a single string.
        transcript_list = YouTubeTranscriptApi.get_transcript(video_id)
        transcript = " ".join(item["text"] for item in transcript_list)
        print("Transcript fetched successfully.")

        # Truncate the transcript so the prompt fits the 4096-token context.
        prompt = (
            f"<s><|system|>\nYou are an expert analyst. Summarize the key points of the following YouTube video transcript."
            f"<|end|>\n<|user|>\nTranscript: {transcript[:3000]}\n\nSummary:<|end|>\n<|assistant|>"
        )

        # Get summary from LLM.
        output = llm(prompt, max_tokens=512, stop=["<|end|>"], echo=False)
        summary = output["choices"][0]["text"].strip()
        return {"video_url": request.video_url, "summary": summary}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def _extract_video_id(video_url: str) -> str:
    """Return the video ID from a watch, youtu.be, or embed YouTube URL.

    Raises:
        ValueError: if no video ID can be found in the URL.
    """
    parsed = urlparse(video_url)
    host = parsed.hostname or ""
    if host.endswith("youtu.be"):
        # Short-link form: https://youtu.be/<id>
        vid = parsed.path.lstrip("/").split("/")[0]
    elif "/embed/" in parsed.path:
        # Embed form: https://www.youtube.com/embed/<id>
        vid = parsed.path.split("/embed/")[1].split("/")[0]
    else:
        # Standard form: https://www.youtube.com/watch?v=<id>&...
        vid = parse_qs(parsed.query).get("v", [""])[0]
    if not vid:
        raise ValueError(f"Could not extract a video ID from URL: {video_url}")
    return vid
@app.get("/")
def read_root():
    """Health-check endpoint: confirms the API process is up."""
    # NOTE: removed a stray trailing "|" scrape artifact that made this line
    # a syntax error.
    return {"status": "API is running."}