"""FastAPI chat server backed by a local GGUF model with optional KRCE RAG grounding.

Importing this module downloads/loads the model and builds the RAG index as
module-level side effects, then exposes the HTTP routes on ``app``.
"""

import os
import re
import sys
import base64
import urllib.parse
from datetime import datetime
from typing import Any

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, Response
from pydantic import BaseModel, Field
import uvicorn

# --- Core dependencies ---
try:
    from llama_cpp import Llama
    print("✅ llama-cpp-python")
except ImportError:
    print("❌ Run: pip install llama-cpp-python")
    sys.exit(1)

from rag_utils import (
    ABSTAIN_MESSAGE,
    build_general_system_prompt,
    build_hybrid_system_prompt,
    build_system_prompt,
    compose_krce_response,
    finalize_general_response,
    finalize_krce_response,
    load_rag_index,
    search_krce,
)

# --- Config ---
# Model settings
REPO_ID = "Krishkanth/krish-mind-mobile"
MODEL_FILENAME = "krish-mind-mobile.gguf"
BASE_DIR = os.path.dirname(__file__)
STATIC_DIR = os.path.join(BASE_DIR, "static")
LOGO_B64_FILE = os.path.join(STATIC_DIR, "logo_png_base64.txt")
default_clean_data = os.path.join(BASE_DIR, "data", "krce_college_data_clean.jsonl")
default_legacy_data = os.path.join(BASE_DIR, "data", "krce_college_data.jsonl")
# Prefer the cleaned dataset; fall back to the legacy file when it is absent.
DATA_FILE = default_clean_data if os.path.exists(default_clean_data) else default_legacy_data
# Decoded PNG bytes, populated lazily by the /logo.png handler.
_logo_png_cache: bytes | None = None

# Image-generation hook: trigger phrases must match as whole words so that
# e.g. "withdraw" or "drawer" does not route to image generation.
_IMAGE_TRIGGER_RE = re.compile(
    r"\b(?:generate image|create image|draw|imagine)\b", re.IGNORECASE
)
# Pure command phrases that are removed from the prompt ("draw"/"imagine" are
# left in place because they usually read naturally inside the description).
_IMAGE_COMMAND_RE = re.compile(r"\b(?:generate image|create image)\b", re.IGNORECASE)

# Conversation-history limits used when building the general-chat prompt.
_MAX_HISTORY_TURNS = 8
_MAX_TURN_CHARS = 1200

# --- Load GGUF Model ---
print(f"\n⏳ Downloading/Loading model from {REPO_ID}...")
try:
    from huggingface_hub import hf_hub_download

    # Download model (cached)
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=MODEL_FILENAME,
        local_dir="model",  # Download to local folder
        local_dir_use_symlinks=False
    )
    print(f"✅ Model downloaded to: {model_path}")
    model = Llama(
        model_path=model_path,
        n_ctx=4096,
        n_gpu_layers=0,  # CPU only for free tier
        verbose=False
    )
    print("✅ Model loaded!")
except Exception as e:
    # Keep the server importable; /chat reports the failure to the client.
    print(f"❌ Model error: {e}")
    model = None

# --- RAG SETUP ---
print("📚 Indexing Knowledge Base...")
rag_index = load_rag_index(DATA_FILE)
if rag_index.model is not None and rag_index.records:
    print(f"✅ Indexed {len(rag_index.records)} KRCE facts.")
else:
    print("⚠️ Data file not found or embedding model unavailable. RAG disabled.")

# --- FastAPI ---
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Serve Static Files
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


class ChatRequest(BaseModel):
    """Request body for the ``/chat`` endpoint."""

    # The user's current message.
    message: str
    # Requested generation budget; clamped per route inside the handler.
    max_tokens: int = 1024
    # Requested sampling temperature; clamped per route inside the handler.
    temperature: float = 0.1
    # True selects the KRCE (RAG-grounded) route, False the general route.
    krce_mode: bool = False
    # Prior turns; the handler reads the "role" and "content" keys of each.
    history: list[dict[str, Any]] = Field(default_factory=list)


@app.get("/")
async def root():
    """Serve ``index.html`` from the static directory at the site root."""
    return FileResponse(os.path.join(STATIC_DIR, "index.html"))


@app.get("/logo.png")
async def logo():
    """Serve the PNG logo decoded from the base64 text file.

    The decoded bytes are cached in ``_logo_png_cache``. When the base64 file
    does not exist, the SVG logo is served instead.
    """
    global _logo_png_cache
    if _logo_png_cache is None:
        if os.path.exists(LOGO_B64_FILE):
            with open(LOGO_B64_FILE, "r", encoding="ascii") as handle:
                _logo_png_cache = base64.b64decode(handle.read().strip())
        else:
            return FileResponse(os.path.join(STATIC_DIR, "logo.svg"), media_type="image/svg+xml")
    return Response(content=_logo_png_cache, media_type="image/png")


@app.get("/logo.svg")
async def logo_svg():
    """Serve the SVG logo from the static directory."""
    return FileResponse(os.path.join(STATIC_DIR, "logo.svg"), media_type="image/svg+xml")


def _extract_image_prompt(user_input: str) -> str | None:
    """Return the image description if *user_input* is an image request, else None.

    The returned string may be empty when the message held only a command phrase.
    """
    if _IMAGE_TRIGGER_RE.search(user_input) is None:
        return None
    # Case-insensitive removal, matching the case-insensitive trigger check.
    stripped = _IMAGE_COMMAND_RE.sub("", user_input)
    # Collapse whitespace so newlines cannot break the markdown image syntax.
    return " ".join(stripped.split())


def _image_response(prompt: str) -> str:
    """Build the markdown reply embedding the generated-image URL for *prompt*."""
    if not prompt:
        return "Please describe the image you'd like me to generate."
    # safe="" so that "/" in the prompt is escaped instead of adding URL path segments.
    url = f"https://image.pollinations.ai/prompt/{urllib.parse.quote(prompt, safe='')}"
    # Square brackets would terminate the markdown alt text early.
    alt_text = prompt.replace("[", "").replace("]", "")
    return f"Here's your image of **{prompt}**:\n\n![{alt_text}]({url})"


def _build_prompt_text(user_input: str, history: list[dict[str, Any]]) -> str:
    """Prefix *user_input* with a compact transcript of the most recent turns.

    Only turns whose role is "user" or "assistant" and whose content is
    non-empty are kept. Returns *user_input* unchanged when nothing qualifies.
    """
    compact_turns: list[str] = []
    for turn in history[-_MAX_HISTORY_TURNS:]:
        # "or ''" keeps an explicit None from being rendered as the text "None".
        role = str(turn.get("role") or "").strip().lower()
        content = str(turn.get("content") or "").strip()
        if role not in {"user", "assistant"} or not content:
            continue
        if len(content) > _MAX_TURN_CHARS:
            content = content[:_MAX_TURN_CHARS].rstrip() + " ..."
        speaker = "User" if role == "user" else "Assistant"
        compact_turns.append(f"{speaker}: {content}")
    if not compact_turns:
        return user_input
    return (
        "Conversation context (most recent turns):\n"
        + "\n".join(compact_turns)
        + "\n\nUser: "
        + user_input
        + "\nAssistant:"
    )


@app.post("/chat")
async def chat(request: ChatRequest):
    """Answer a chat message and return ``{"response": <text>}``.

    Flow:
      1. If the model failed to load, return an error message.
      2. If the message contains an image trigger phrase, return a markdown
         image link instead of calling the model.
      3. In KRCE mode, run retrieval; return ``ABSTAIN_MESSAGE`` when the
         retriever says to abstain, or a composed answer when there are hits.
      4. Otherwise call the model with a route-specific system prompt and
         clamped token/temperature settings, continuing once if a general
         answer was cut off by the token limit.

    Generation failures are logged and reported in the response text rather
    than raised.
    """
    if not model:
        return {"response": "Error: Model not loaded. Please check server logs."}
    user_input = request.message

    # Image Generation Hook
    image_prompt = _extract_image_prompt(user_input)
    if image_prompt is not None:
        return {"response": _image_response(image_prompt)}

    # Frontend controls route explicitly:
    # - KRCE mode ON: strict grounded KRCE answers only
    # - KRCE mode OFF: normal model chat without RAG retrieval
    # NOTE(review): "hybrid" is never assigned here, so the hybrid branches
    # below are currently unreachable from this handler.
    route = "krce" if bool(request.krce_mode) else "general"
    rag_result = {
        "context": "",
        "hits": [],
        "should_abstain": False,
        "confidence": 0.0,
    }
    if route in {"krce", "hybrid"}:
        rag_result = search_krce(user_input, rag_index)
        if rag_result["context"]:
            print(f"\n[📦 RAG CONTEXT FOUND]\n{rag_result['context']}\n")
    if route == "krce" and rag_result["should_abstain"]:
        return {"response": ABSTAIN_MESSAGE}
    if route == "krce" and rag_result.get("hits"):
        # Retrieval hits are answered without calling the LLM.
        response_text = compose_krce_response(user_input, rag_result)
        return {"response": finalize_krce_response(user_input, response_text, rag_result)}

    now = datetime.now().strftime("%A, %B %d, %Y")
    if route == "hybrid":
        sys_prompt = build_hybrid_system_prompt(now, rag_result)
    elif route == "general":
        sys_prompt = build_general_system_prompt(now)
    else:
        sys_prompt = build_system_prompt(now, user_input, rag_result)

    prompt_text = user_input
    if route == "general" and request.history:
        prompt_text = _build_prompt_text(user_input, request.history)
    full_prompt = f"<|system|>\n{sys_prompt}<|end|>\n<|user|>\n{prompt_text}<|end|>\n<|assistant|>\n"
    # Enforce strict stop tokens to prevent the model from hallucinating user prompts or looping
    stop_tokens = ["<|end|>", "<|endoftext|>", "<|user|>", "<|system|>"]
    try:
        max_allowed = 420 if route == "krce" else 1200
        effective_tokens = max(64, min(int(request.max_tokens), max_allowed))
        if route == "krce":
            # Keep grounded answers near-deterministic; never go below zero.
            effective_temp = max(0.0, min(request.temperature, 0.1))
        else:
            effective_temp = min(max(request.temperature, 0.2), 0.6)
        # NOTE(review): this call is synchronous inside an async handler, so
        # it blocks the event loop for the duration of generation.
        output = model(
            full_prompt,
            max_tokens=effective_tokens,
            temperature=effective_temp,
            repeat_penalty=1.15,  # Prevents text repeating/gibberish loops
            stop=stop_tokens,
            echo=False
        )
        response_text = output["choices"][0]["text"].strip()
        finish_reason = str(output["choices"][0].get("finish_reason", "")).lower()
        if route == "general" and finish_reason == "length" and response_text:
            # The answer hit the token limit: ask once for a continuation.
            continue_prompt = (
                f"{full_prompt}{response_text}\n"
                "Continue from where it stopped. Do not repeat previous lines. "
                "Finish the answer clearly."
            )
            cont = model(
                continue_prompt,
                max_tokens=min(400, max_allowed),
                temperature=max(0.15, min(effective_temp, 0.4)),
                repeat_penalty=1.12,
                stop=stop_tokens,
                echo=False,
            )
            extra = cont["choices"][0]["text"].strip()
            if extra:
                response_text = (response_text + "\n" + extra).strip()
        if route == "krce":
            return {"response": finalize_krce_response(user_input, response_text, rag_result)}
        return {"response": finalize_general_response(user_input, response_text)}
    except Exception as e:
        # Log server-side as well; the response text is all the client sees.
        print(f"❌ Generation error: {e}")
        return {"response": f"Error: {e}"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)