import os
import sys
import base64
import urllib.parse
from datetime import datetime
from typing import Any

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, Response
from pydantic import BaseModel, Field
import uvicorn
# --- Core dependencies ---
try:
    from llama_cpp import Llama
    print("✅ llama-cpp-python")
except ImportError:
    print("❌ Run: pip install llama-cpp-python")
    sys.exit(1)
from rag_utils import (
    ABSTAIN_MESSAGE,
    build_general_system_prompt,
    build_hybrid_system_prompt,
    build_system_prompt,
    compose_krce_response,
    finalize_general_response,
    finalize_krce_response,
    load_rag_index,
    search_krce,
)
# --- Config ---
# Model settings
REPO_ID = "Krishkanth/krish-mind-mobile"
MODEL_FILENAME = "krish-mind-mobile.gguf"
BASE_DIR = os.path.dirname(__file__)
STATIC_DIR = os.path.join(BASE_DIR, "static")
LOGO_B64_FILE = os.path.join(STATIC_DIR, "logo_png_base64.txt")
default_clean_data = os.path.join(BASE_DIR, "data", "krce_college_data_clean.jsonl")
default_legacy_data = os.path.join(BASE_DIR, "data", "krce_college_data.jsonl")
DATA_FILE = default_clean_data if os.path.exists(default_clean_data) else default_legacy_data

_logo_png_cache: bytes | None = None
# --- Load GGUF Model ---
print(f"\n⏳ Downloading/Loading model from {REPO_ID}...")
try:
    from huggingface_hub import hf_hub_download

    # Download model (cached)
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=MODEL_FILENAME,
        local_dir="model",  # Download to local folder
        local_dir_use_symlinks=False,
    )
    print(f"✅ Model downloaded to: {model_path}")

    model = Llama(
        model_path=model_path,
        n_ctx=4096,
        n_gpu_layers=0,  # CPU only for free tier
        verbose=False,
    )
    print("✅ Model loaded!")
except Exception as e:
    print(f"❌ Model error: {e}")
    model = None
# --- RAG SETUP ---
print("📚 Indexing Knowledge Base...")
rag_index = load_rag_index(DATA_FILE)
if rag_index.model is not None and rag_index.records:
    print(f"✅ Indexed {len(rag_index.records)} KRCE facts.")
else:
    print("⚠️ Data file not found or embedding model unavailable. RAG disabled.")
# --- FastAPI ---
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# Serve Static Files
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
class ChatRequest(BaseModel):
    message: str
    max_tokens: int = 1024
    temperature: float = 0.1
    krce_mode: bool = False
    history: list[dict[str, Any]] = Field(default_factory=list)
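
# Example request body accepted by this schema (illustrative values only):
#   {
#     "message": "What courses does KRCE offer?",
#     "max_tokens": 512,
#     "temperature": 0.1,
#     "krce_mode": true,
#     "history": [
#       {"role": "user", "content": "hi"},
#       {"role": "assistant", "content": "Hello! How can I help?"}
#     ]
#   }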
@app.get("/")
async def root():
    # Serve index.html at root
    return FileResponse(os.path.join(STATIC_DIR, "index.html"))


# The logo route paths below are assumed from the handler names and the static assets they serve.
@app.get("/logo.png")
async def logo():
    global _logo_png_cache
    if _logo_png_cache is None:
        if os.path.exists(LOGO_B64_FILE):
            with open(LOGO_B64_FILE, "r", encoding="ascii") as handle:
                _logo_png_cache = base64.b64decode(handle.read().strip())
        else:
            return FileResponse(os.path.join(STATIC_DIR, "logo.svg"), media_type="image/svg+xml")
    return Response(content=_logo_png_cache, media_type="image/png")


@app.get("/logo.svg")
async def logo_svg():
    return FileResponse(os.path.join(STATIC_DIR, "logo.svg"), media_type="image/svg+xml")
# Endpoint path is assumed; adjust it to match the URL the frontend posts to.
@app.post("/chat")
async def chat(request: ChatRequest):
    if not model:
        return {"response": "Error: Model not loaded. Please check server logs."}

    user_input = request.message

    # Image Generation Hook
    if any(t in user_input.lower() for t in ["generate image", "create image", "draw", "imagine"]):
        prompt = user_input.replace("generate image", "").strip()
        url = f"https://image.pollinations.ai/prompt/{urllib.parse.quote(prompt)}"
        # Embed the generated-image URL as a markdown image so the frontend can render it
        return {"response": f"Here's your image of **{prompt}**:\n\n![{prompt}]({url})"}

    # Frontend controls the route explicitly:
    # - KRCE mode ON: strict grounded KRCE answers only
    # - KRCE mode OFF: normal model chat without RAG retrieval
    route = "krce" if bool(request.krce_mode) else "general"
    rag_result = {
        "context": "",
        "hits": [],
        "should_abstain": False,
        "confidence": 0.0,
    }
    if route in {"krce", "hybrid"}:
        rag_result = search_krce(user_input, rag_index)
        if rag_result["context"]:
            print(f"\n[📦 RAG CONTEXT FOUND]\n{rag_result['context']}\n")

    if route == "krce" and rag_result["should_abstain"]:
        return {"response": ABSTAIN_MESSAGE}
    if route == "krce" and rag_result.get("hits"):
        response_text = compose_krce_response(user_input, rag_result)
        return {"response": finalize_krce_response(user_input, response_text, rag_result)}
    now = datetime.now().strftime("%A, %B %d, %Y")
    if route == "hybrid":
        sys_prompt = build_hybrid_system_prompt(now, rag_result)
    elif route == "general":
        sys_prompt = build_general_system_prompt(now)
    else:
        sys_prompt = build_system_prompt(now, user_input, rag_result)

    prompt_text = user_input
    if route == "general" and request.history:
        compact_turns: list[str] = []
        for turn in request.history[-8:]:
            role = str(turn.get("role", "")).strip().lower()
            content = str(turn.get("content", "")).strip()
            if role not in {"user", "assistant"} or not content:
                continue
            if len(content) > 1200:
                content = content[:1200].rstrip() + " ..."
            speaker = "User" if role == "user" else "Assistant"
            compact_turns.append(f"{speaker}: {content}")
        if compact_turns:
            prompt_text = (
                "Conversation context (most recent turns):\n"
                + "\n".join(compact_turns)
                + "\n\nUser: "
                + user_input
                + "\nAssistant:"
            )
    full_prompt = f"<|system|>\n{sys_prompt}<|end|>\n<|user|>\n{prompt_text}<|end|>\n<|assistant|>\n"
    # Enforce strict stop tokens to prevent the model from hallucinating user prompts or looping
    stop_tokens = ["<|end|>", "<|endoftext|>", "<|user|>", "<|system|>"]
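    # The <|system|>/<|user|>/<|assistant|>/<|end|> markers follow a Phi-3-style chat template;
    # this assumes the GGUF was fine-tuned from such a base. If the model uses a different
    # template, update both full_prompt and stop_tokens to match its special tokens.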
    try:
        max_allowed = 420 if route == "krce" else 1200
        effective_tokens = max(64, min(int(request.max_tokens), max_allowed))
        effective_temp = min(request.temperature, 0.1) if route == "krce" else min(max(request.temperature, 0.2), 0.6)
        output = model(
            full_prompt,
            max_tokens=effective_tokens,
            temperature=effective_temp,
            repeat_penalty=1.15,  # Prevents text repeating/gibberish loops
            stop=stop_tokens,
            echo=False,
        )
        response_text = output["choices"][0]["text"].strip()
        finish_reason = str(output["choices"][0].get("finish_reason", "")).lower()

        # If a general-mode answer was cut off by the token limit, ask the model to continue once
        if route == "general" and finish_reason == "length" and response_text:
            continue_prompt = (
                f"{full_prompt}{response_text}\n"
                "Continue from where it stopped. Do not repeat previous lines. "
                "Finish the answer clearly."
            )
            cont = model(
                continue_prompt,
                max_tokens=min(400, max_allowed),
                temperature=max(0.15, min(effective_temp, 0.4)),
                repeat_penalty=1.12,
                stop=stop_tokens,
                echo=False,
            )
            extra = cont["choices"][0]["text"].strip()
            if extra:
                response_text = (response_text + "\n" + extra).strip()

        if route == "krce":
            return {"response": finalize_krce_response(user_input, response_text, rag_result)}
        return {"response": finalize_general_response(user_input, response_text)}
    except Exception as e:
        return {"response": f"Error: {e}"}
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
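
# Quick smoke test once the server is up (uses the /chat route path assumed above):
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Tell me about KRCE", "krce_mode": true}'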