|
|
import os
from typing import Type

from fastapi import FastAPI, HTTPException, BackgroundTasks, Header
from fastapi.middleware.cors import CORSMiddleware

from backend.config.logging_config import setup_default_logging, get_logger
from backend.config.settings import settings
from backend.services.vector_store import vector_store_service
from backend.utils.request_dto.chat_response import ChatResponse
from backend.utils.request_dto.scrape_request import ScrapeRequest
from backend.utils.sanitization import sanitize_user_input
from backend.utils.types import ChatMessage

from data_minning.all_nigerian_recipe_scraper import AllNigerianRecipesScraper
from data_minning.base_scrapper import BaseRecipeScraper, JsonArraySink, MongoSink
from data_minning.dto.stream_opts import StreamOptions
from data_minning.yummy_medley_scraper import YummyMedleyScraper
|
|
|
|
|
# Configure logging first so everything imported/initialized below (including
# the llm_service import further down) logs through the configured handlers.
setup_default_logging()

logger = get_logger("app")
|
|
|
|
|
|
|
|
from backend.services.llm_service import llm_service |
|
|
|
|
|
# Registry of available scraper implementations, keyed by the short site
# name accepted by the /scrape endpoint.
SCRAPERS: dict[str, Type[BaseRecipeScraper]] = {
    "yummy": YummyMedleyScraper,
    "anr": AllNigerianRecipesScraper,
}
|
|
|
|
|
# FastAPI application instance; this metadata is surfaced in the generated
# OpenAPI/Swagger docs.
app = FastAPI(
    title="Recipe Recommendation Bot API",
    description="AI-powered recipe recommendation system with RAG capabilities",
    version="1.0.0"
)
|
|
|
|
|
# Log startup context so a deployment can be sanity-checked from the logs.
logger.info("🚀 Starting Recipe Recommendation Bot API")
logger.info(f"Environment: {settings.ENVIRONMENT}")
logger.info(f"Provider: {settings.get_llm_config()['provider']} (LLM + Embeddings)")
|
|
|
|
|
|
|
|
# CORS configuration. Each setting falls back to a permissive default when unset.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS or ["*"],
    # Bug fix: `settings.CORS_ALLOW_CREDENTIALS or True` always evaluated to
    # True (False or True == True), making it impossible to disable
    # credentials. Only fall back to the default when the setting is unset.
    allow_credentials=(
        settings.CORS_ALLOW_CREDENTIALS
        if settings.CORS_ALLOW_CREDENTIALS is not None
        else True
    ),
    allow_methods=settings.CORS_ALLOW_METHODS or ["*"],
    allow_headers=settings.CORS_ALLOW_HEADERS or ["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
def index(): |
|
|
logger.info("📡 Root endpoint accessed") |
|
|
return { |
|
|
"message": "Recipe Recommendation Bot API", |
|
|
"version": "1.0.0", |
|
|
"status": "running" |
|
|
} |
|
|
|
|
|
@app.get("/health") |
|
|
def health_check(): |
|
|
logger.info("🏥 Health check endpoint accessed") |
|
|
return { |
|
|
"status": "healthy", |
|
|
"environment": settings.ENVIRONMENT, |
|
|
"llm_service_initialized": llm_service is not None |
|
|
} |
|
|
|
|
|
@app.post("/chat", response_model=ChatResponse) |
|
|
async def chat(chat_message: ChatMessage): |
|
|
"""Main chatbot endpoint - Recipe recommendation with ConversationalRetrievalChain""" |
|
|
try: |
|
|
|
|
|
|
|
|
last_user_message = chat_message.get_latest_message() |
|
|
if not last_user_message: |
|
|
raise ValueError("No valid user message found") |
|
|
user_text = last_user_message.parts[0].text |
|
|
|
|
|
response_text = llm_service.ask_question(user_text) |
|
|
return ChatResponse(response=response_text) |
|
|
|
|
|
except ValueError as e: |
|
|
|
|
|
logger.warning(f"⚠️ Invalid input received: {str(e)}") |
|
|
raise HTTPException(status_code=400, detail=f"Invalid input: {str(e)}") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"❌ Chat service error: {str(e)}", exc_info=True) |
|
|
raise HTTPException(status_code=500, detail=f"Chat service error: {str(e)}") |
|
|
|
|
|
@app.get("/demo") |
|
|
def demo(prompt: str = "What recipes do you have?"): |
|
|
"""Demo endpoint - uses simple chat completion without RAG""" |
|
|
logger.info(f"🎯 Demo request: '{prompt[:50]}...'") |
|
|
|
|
|
try: |
|
|
|
|
|
sanitized_prompt = sanitize_user_input(prompt) |
|
|
response_text = llm_service.simple_chat_completion(sanitized_prompt) |
|
|
return {"prompt": sanitized_prompt, "reply": response_text} |
|
|
|
|
|
except ValueError as e: |
|
|
|
|
|
logger.warning(f"⚠️ Invalid demo prompt: {str(e)}") |
|
|
return {"error": f"Invalid prompt: {str(e)}", "prompt": prompt} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"❌ Demo endpoint error: {str(e)}", exc_info=True) |
|
|
return {"error": f"Failed to get response: {str(e)}"} |
|
|
|
|
|
@app.post("/clear-memory") |
|
|
def clear_conversation_memory(): |
|
|
"""Clear conversation memory""" |
|
|
logger.info("🧹 Memory clear request received") |
|
|
|
|
|
try: |
|
|
success = llm_service.clear_memory() |
|
|
|
|
|
if success: |
|
|
logger.info("✅ Conversation memory cleared successfully") |
|
|
return {"status": "success", "message": "Conversation memory cleared"} |
|
|
else: |
|
|
logger.warning("⚠️ Memory clear operation failed") |
|
|
return {"status": "failed", "message": "Failed to clear conversation memory"} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"❌ Memory clear error: {str(e)}", exc_info=True) |
|
|
return {"status": "error", "message": str(e)} |
|
|
|
|
|
|
|
|
|
|
|
def run_job(job_id: str, site: str, limit: int, output_type: str):
    """Background job that runs the scraper for *site*.

    Tracks status in the module-level JOBS dict and writes results to a JSON
    file or MongoDB depending on *output_type*.

    Args:
        job_id: key under which status/progress is recorded in JOBS.
        site: key into SCRAPERS (validated by the /scrape endpoint).
        limit: maximum number of records to scrape.
        output_type: "json" or "mongo"; anything else runs with no sink.
    """
    scraper = SCRAPERS[site]()
    scraper.embedder = vector_store_service._create_sentence_transformer_wrapper(
        "sentence-transformers/all-MiniLM-L6-v2"
    )
    scraper.embedding_fields = [(("title", "ingredients", "instructions"), "recipe_emb")]

    sink = None
    if output_type == "json":
        sink = JsonArraySink("./data/recipes_unified.json")
    elif output_type == "mongo":
        if os.getenv("MONGODB_URI"):
            sink = MongoSink()
        else:
            # Previously this fell through silently and scraped with no sink;
            # keep the best-effort behavior but make it visible in the logs.
            logger.warning(f"⚠️ MONGODB_URI not set; job {job_id} will run without a sink")

    stream_opts = StreamOptions(
        delay=0.3,
        # Bug fix: the caller-supplied limit was previously passed as
        # batch_size while the overall limit was hardcoded to 500. Respect
        # the requested limit and use a fixed, reasonable batch size.
        limit=limit,
        batch_size=50,
        resume_file="recipes.resume",
        progress_callback=make_progress_cb(job_id),
    )
    try:
        JOBS[job_id] = {"status": "running", "count": 0}
        scraper.stream(sink=sink, options=stream_opts)
        JOBS[job_id]["status"] = "done"
    except Exception as e:
        JOBS[job_id] = {"status": "error", "error": str(e)}
|
|
|
|
|
def make_progress_cb(job_id: str):
    """Build a progress callback that records the scraped count for *job_id* in JOBS."""
    def update_count(n: int):
        JOBS[job_id]["count"] = n

    return update_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# In-memory job registry: job_id -> {"status": ..., "count"/"error": ...}.
# NOTE(review): not persisted and per-process only — state is lost on restart
# and not shared across multiple workers; confirm single-worker deployment.
# (Annotation fixed: builtin `any` is a function, not a type.)
JOBS: dict[str, dict] = {}
|
|
|
|
|
@app.post("/scrape") |
|
|
def scrape(body: ScrapeRequest, background: BackgroundTasks, x_api_key: str = Header(None)): |
|
|
if body.site not in SCRAPERS: |
|
|
raise HTTPException(status_code=400, detail="Unknown site") |
|
|
|
|
|
job_id = f"{body.site}-{os.urandom(4).hex()}" |
|
|
|
|
|
background.add_task(run_job, job_id, body.site, body.limit, body.output_type) |
|
|
return {"job_id": job_id, "status": "queued"} |
|
|
|
|
|
@app.get("/jobs/{job_id}") |
|
|
def job_status(job_id: str): |
|
|
return JOBS.get(job_id, {"status": "unknown"}) |
|
|
|
|
|
@app.get("/jobs") |
|
|
def list_jobs(): |
|
|
return JOBS |
|
|
|