# Enhanced RAG API: FastAPI service with chat memory (ChromaDB) and YouTube transcript analysis.
# Standard library
import os
import uuid
from typing import List
from urllib.parse import parse_qs, urlparse

# Web framework
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# Model and DB libraries
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
import chromadb
from youtube_transcript_api import YouTubeTranscriptApi
# --- 1. Constants and Configuration ---
MODEL_REPO = "bartowski/Phi-3.5-mini-instruct_Uncensored-GGUF"  # Hugging Face repo hosting the GGUF quantizations
GGUF_FILE = "Phi-3.5-mini-instruct_Uncensored-Q4_K_M.gguf"  # Good balance (of size vs. quality)
CHROMA_PATH = "/app/chroma_db"  # Path inside the container for persistent storage
COLLECTION_NAME = "chat_history"  # ChromaDB collection that stores past exchanges
# --- 2. Create the FastAPI application (metadata shown in the auto-generated docs) ---
app = FastAPI(
    version="1.0",
    title="Enhanced RAG API with Memory",
    description="An API with chat history (ChromaDB) and YouTube analysis.",
)
# --- 3. Global Variables (will be loaded on startup) ---
llm: Llama = None  # llama.cpp model handle; populated by load_resources()
chroma_client: chromadb.Client = None  # persistent ChromaDB client; populated by load_resources()
collection: chromadb.Collection = None  # collection holding chat history; populated by load_resources()
# --- 4. Startup Event: Load models and initialize DB ---
@app.on_event("startup")  # run once when the server starts (decorator restored; see section comment)
def load_resources():
    """Download/load the LLM and open the persistent ChromaDB collection.

    Populates the module-level globals ``llm``, ``chroma_client`` and
    ``collection`` that the chat and YouTube endpoints rely on.
    """
    global llm, chroma_client, collection

    # Load the LLM (hf_hub_download caches the file locally between runs).
    print("Downloading and loading LLM...")
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=GGUF_FILE)
    # n_gpu_layers=-1 offloads all layers to the GPU when one is available.
    llm = Llama(model_path=model_path, n_ctx=4096, n_gpu_layers=-1, verbose=True)
    print("LLM loaded.")

    # Initialize ChromaDB client.
    print("Initializing ChromaDB...")
    # This creates a persistent DB client that stores data in the specified path.
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    # Get or create the collection to store chat history.
    collection = chroma_client.get_or_create_collection(name=COLLECTION_NAME)
    print("ChromaDB initialized.")
    print("API is ready to go! ๐")
# --- 5. Pydantic Models for API requests ---
class ChatRequest(BaseModel):
    """Request body for the chat-with-memory endpoint."""
    session_id: str = Field(..., description="Unique identifier for a chat session.")
    message: str = Field(..., description="The user's message.")
class YouTubeRequest(BaseModel):
    """Request body for the YouTube analysis endpoint."""
    video_url: str = Field(..., description="URL of the YouTube video to analyze.")
# --- 6. API Endpoint for Chat with Memory --- (decorator restored; see section comment)
@app.post("/chat")
def chat_with_memory(request: ChatRequest):
    """Answer a chat message using the session's stored history as context.

    Retrieves up to 5 relevant past exchanges for the session from
    ChromaDB, prompts the LLM with them, then persists the new exchange.
    Returns ``{"session_id": ..., "response": ...}``.
    """
    print(f"Received chat request for session: {request.session_id}")

    # Step 1: Retrieve relevant chat history from ChromaDB.
    # BUGFIX: Chroma's query() requires query_texts (or query_embeddings);
    # the original call passed only `where`, which always raised and
    # silently threw history away. Querying with the new message returns
    # the most semantically relevant past exchanges for this session.
    try:
        history = collection.query(
            query_texts=[request.message],
            where={"session_id": request.session_id},
            n_results=5,  # Get the last 5 exchanges
        )
        # Format history for the prompt.
        context = "\n".join(
            f"User: {meta['user_message']}\nAI: {doc}"
            for doc, meta in zip(history["documents"][0], history["metadatas"][0])
        )
    except Exception as e:
        print(f"Error querying ChromaDB: {e}")
        context = ""  # Start fresh if history fails

    # Step 2: Construct the prompt with history (Phi-3 chat template markers).
    prompt_template = (
        "<s><|system|>\nYou are a helpful AI assistant. "
        "Use the chat history below to provide a relevant and coherent response.\n\n"
        "--- Chat History ---\n{chat_history}\n--- End History ---\n<|end|>\n"
        "<|user|>\n{user_message}<|end|>\n<|assistant|>"
    )
    prompt = prompt_template.format(chat_history=context, user_message=request.message)

    # Step 3: Generate a response from the LLM.
    output = llm(prompt=prompt, max_tokens=256, stop=["<|end|>", "User:"], echo=False)
    ai_response = output["choices"][0]["text"].strip()

    # Step 4: Save the new exchange to ChromaDB (best-effort; failure must
    # not prevent the response from being returned).
    try:
        # We store the AI response as the document and the user message in metadata.
        doc_id = str(uuid.uuid4())
        collection.add(
            ids=[doc_id],
            documents=[ai_response],
            metadatas=[{"session_id": request.session_id, "user_message": request.message}],
        )
        print(f"Saved new exchange to session {request.session_id}")
    except Exception as e:
        print(f"Error saving to ChromaDB: {e}")

    return {"session_id": request.session_id, "response": ai_response}
# --- 7. API Endpoint for YouTube Video Analysis --- (decorator restored; see section comment)
def _extract_video_id(video_url: str) -> str:
    """Return the YouTube video ID embedded in *video_url*.

    Handles standard watch URLs (``...watch?v=ID&...``) and ``youtu.be``
    short links. The original ``split("v=")[1]`` raised IndexError on any
    URL without a ``v=`` parameter. Raises ValueError when no ID is found.
    """
    parsed = urlparse(video_url)
    if parsed.hostname and parsed.hostname.endswith("youtu.be"):
        # Short link: ID is the first path segment.
        candidate = parsed.path.lstrip("/").split("/")[0]
    else:
        # Standard link: ID is the `v` query parameter.
        candidate = parse_qs(parsed.query).get("v", [""])[0]
    if not candidate:
        raise ValueError(f"Could not extract a video ID from URL: {video_url}")
    return candidate


@app.post("/analyze_youtube")
def analyze_youtube_video(request: YouTubeRequest):
    """Fetch a YouTube video's transcript and summarize it with the LLM.

    Returns ``{"video_url": ..., "summary": ...}``; any transcript or
    generation failure is surfaced to the client as HTTP 500.
    """
    try:
        # Extract video ID from URL (supports watch and youtu.be forms).
        video_id = _extract_video_id(request.video_url)
        print(f"Fetching transcript for video ID: {video_id}")

        # Get transcript and flatten it into a single string.
        transcript_list = YouTubeTranscriptApi.get_transcript(video_id)
        transcript = " ".join(item["text"] for item in transcript_list)
        print("Transcript fetched successfully.")

        # Create a prompt for summarization.
        prompt = (
            f"<s><|system|>\nYou are an expert analyst. Summarize the key points of the following YouTube video transcript."
            f"<|end|>\n<|user|>\nTranscript: {transcript[:3000]}\n\nSummary:<|end|>\n<|assistant|>"  # Truncate to fit context
        )

        # Get summary from LLM.
        output = llm(prompt, max_tokens=512, stop=["<|end|>"], echo=False)
        summary = output["choices"][0]["text"].strip()
        return {"video_url": request.video_url, "summary": summary}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/")  # health-check route (decorator restored — function was unreachable without it)
def read_root():
    """Health check: confirm the API process is up."""
    return {"status": "API is running."}