import os
import json
import asyncio
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from llama_cpp import Llama
from app.policy_vector_db import PolicyVectorDB, ensure_db_populated

# -----------------------------
# Logging Configuration
# -----------------------------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("app")

# -----------------------------
# Initialize FastAPI App
# -----------------------------
app = FastAPI()

@app.get("/")
async def root():
    # Simple health-check endpoint confirming the service is up.
    return {"status": "Server is running and ready."}

# -----------------------------
# Vector DB and Data Configuration
# -----------------------------
DB_PERSIST_DIRECTORY = "/app/vector_database"
# This file is generated by the create_granular_chunks.py script in the Dockerfile
CHUNKS_FILE_PATH = "/app/processed_chunks.json"

logger.info("[INFO] Initializing vector DB...")
db = PolicyVectorDB(
    persist_directory=DB_PERSIST_DIRECTORY,
    top_k_default=5,
    relevance_threshold=0.35  # Start with a reasonable threshold for granular chunks
)
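
# Note: the code below assumes db.search(question) returns a list of dicts,
# each with at least a "text" field (the chunk content) and a "relevance_score"
# field (higher = more relevant); this mirrors how the results are consumed
# in the /chat endpoint further down.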

# This function now runs on startup to populate the DB if it's empty
if not ensure_db_populated(db, CHUNKS_FILE_PATH):
    logger.warning("[WARNING] DB not populated. Chunks file may be missing or empty. RAG will not function correctly.")
else:
    logger.info("[INFO] Vector DB is ready.")

# -----------------------------
# Load Your Re-Quantized GGUF Model
# -----------------------------
# Points to the compatible GGUF file downloaded in the Dockerfile
MODEL_PATH = "/app/phi1.5_dop_q4_k_m.gguf"

logger.info(f"[INFO] Loading GGUF model from: {MODEL_PATH}")
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=2048,       # context window: context + question + generated tokens must fit in 2048 tokens
    n_threads=2,      # CPU threads used for inference
    n_gpu_layers=0,   # keep all layers on the CPU (no GPU offload)
    verbose=False     # silence llama.cpp's internal logging
)
logger.info("[INFO] Model loaded successfully.")

# -----------------------------
# API Schemas
# -----------------------------
class Query(BaseModel):
    question: str

class Feedback(BaseModel):
    question: str
    answer: str
    feedback: str
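
# Illustrative request bodies for these schemas (the field values below are
# made-up examples, not taken from the policy data):
#   Query:    {"question": "How many days of annual leave do employees get?"}
#   Feedback: {"question": "...", "answer": "...", "feedback": "The answer missed the latest policy update."}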

# -----------------------------
# Endpoints
# -----------------------------
LLM_TIMEOUT_SECONDS = int(os.getenv("LLM_TIMEOUT_SECONDS", "45"))
logger.info(f"[INFO] LLM_TIMEOUT_SECONDS set to: {LLM_TIMEOUT_SECONDS} seconds.")

async def generate_llm_response(prompt: str):
    """Run the blocking llama.cpp call in a worker thread so the event loop stays
    responsive and the asyncio timeout below can actually fire."""
    response = await asyncio.to_thread(
        llm, prompt, max_tokens=384, stop=["Instruct:", "Output:", "###"], temperature=0.2, echo=False
    )
    answer = response["choices"][0]["text"].strip()
    if not answer:
        raise ValueError("Empty response from LLM")
    return answer

@app.post("/chat")
async def chat(query: Query):
    question = query.question.strip()
    logger.info(f"[QUERY] {question}")
    search_results = db.search(question)
    filtered = sorted(
        [r for r in search_results if r["relevance_score"] > db.relevance_threshold],
        key=lambda x: x["relevance_score"],
        reverse=True
    )
    if not filtered:
        logger.info("[RESPONSE] No relevant context found.")
        return {
            "question": question,
            "context_used": "No relevant context found.",
            "answer": "Sorry, I could not find a relevant policy to answer that question. Please try rephrasing."
        }
    context = filtered[0]["text"]
    logger.info(f"[INFO] Using top context (score: {filtered[0]['relevance_score']:.4f})")
    # This prompt format matches how you fine-tuned Phi-1.5
    prompt = f"""Instruct: Use the following context to answer the question.
Context: {context}
Question: {question}
Output:"""
| answer = "Sorry, I couldn't process your request right now. Please try again later." | |
| try: | |
| answer = await asyncio.wait_for(generate_llm_response(prompt), timeout=LLM_TIMEOUT_SECONDS) | |
| except asyncio.TimeoutError: | |
| logger.warning(f"[TIMEOUT] LLM generation timed out after {LLM_TIMEOUT_SECONDS} seconds.") | |
| answer = "Sorry, the request took too long to process. Please try again with a simpler question." | |
| except Exception as e: | |
| logger.error(f"[ERROR] An unexpected error occurred during LLM generation: {str(e)}") | |
| answer = "Sorry, an unexpected error occurred while generating a response." | |
| logger.info(f"[RESPONSE] Answered: {answer[:100]}...") | |
| return { | |
| "question": question, | |
| "context_used": context, | |
| "answer": answer | |
| } | |

@app.post("/feedback")
async def collect_feedback(feedback: Feedback):
    # Feedback is only logged for now.
    logger.info(f"[FEEDBACK] Question: {feedback.question} | Answer: {feedback.answer} | Feedback: {feedback.feedback}")
    return {"status": "Feedback recorded. Thank you!"}