langchain_rag_chat / run_fast_api.py
Stanislav
feat: model downloading style
1f8dc33
from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi import Cookie
from pydantic import BaseModel
from rag_pipeline.data_loader import CocktailLoader
from rag_pipeline.vector_store import CocktailVectorStore
from rag_pipeline.llm_interface import LocalLLM
from rag_pipeline.user_memory import UserMemory
from rag_pipeline.rag_engine import RAGEngine
from huggingface_hub import hf_hub_download
import os
from uuid import uuid4
app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")
# ====== CHECK & DOWNLOAD MODEL ======
FILENAME = "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
REPO = "stkrk/tinyllama-gguf"
MODEL_DIR = "models"
MODEL_PATH = os.path.join(MODEL_DIR, FILENAME)
os.makedirs(MODEL_DIR, exist_ok=True)
if not os.path.exists(MODEL_PATH):
print(f"Model not found locally. Downloading from {REPO}...")
from huggingface_hub import hf_hub_download
import shutil
cached_path = hf_hub_download(
repo_id=REPO,
filename=FILENAME,
local_dir_use_symlinks=False
)
shutil.copy(cached_path, MODEL_PATH)
print(f"Model downloaded and copied to {MODEL_PATH}.")
else:
print(f"Model already exists at {MODEL_PATH}.")
# ===== 1. Main pipeline =====
print("Initializing RAG pipeline...")
loader = CocktailLoader("data/cocktails.csv")
cocktails = loader.load()
vector_store = CocktailVectorStore(cocktails)
llm = LocalLLM(MODEL_PATH)
memory = UserMemory()
engine = RAGEngine(vector_store, llm, memory)
print("RAG engine is ready!")
# ===== 2. HTML-chat =====
chat_histories = {} # session_id -> list of chat messages
@app.get("/", response_class=HTMLResponse)
async def homepage(request: Request, session_id: str = Cookie(default=None)):
if not session_id:
session_id = str(uuid4())
if session_id not in chat_histories:
chat_histories[session_id] = []
response = templates.TemplateResponse("chat.html", {
"request": request,
"chat_history": chat_histories[session_id]
})
response.set_cookie(key="session_id", value=session_id)
return response
@app.post("/ask", response_class=HTMLResponse)
async def ask(request: Request, message: str = Form(...), session_id: str = Cookie(default=None)):
if not session_id:
session_id = str(uuid4())
if session_id not in chat_histories:
chat_histories[session_id] = []
response = engine.run(message)
chat_histories[session_id].append({"user": message, "bot": response})
html = templates.TemplateResponse("chat.html", {
"request": request,
"chat_history": chat_histories[session_id]
})
html.set_cookie(key="session_id", value=session_id)
return html
@app.get("/static/favicon.png", include_in_schema=False)
async def favicon():
return FileResponse("static/favicon.png")
# ===== 3. REST API endpoint =====
class Query(BaseModel):
message: str
@app.post("/api/ask")
async def api_ask(query: Query):
response = engine.run(query.message)
return {"answer": response}
if __name__ == "__main__":
import uvicorn
uvicorn.run("run_fast_api:app", host="0.0.0.0", port=7860)