# catering-RAG / app.py
# (HuggingFace Space app — source header; originally scraped page text:
#  "mlbench123's picture / Update app.py / e397a5a verified")
import os
import json
from typing import Any, Dict
import requests
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import dotenv
import gradio as gr
dotenv.load_dotenv()
# ======================================================
# ENVIRONMENT
# ======================================================
HF_TOKEN = os.getenv("HF_TOKEN")
AUTH_COOKIE_TOKEN = os.getenv("AUTH_COOKIE_TOKEN")
BASE_API = os.getenv("BASE_API", "https://catering-lac.vercel.app")
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
HF_API_URL = f"https://router.huggingface.co/hf-inference/models/{MODEL_NAME}"
# Startup diagnostics: report presence of secrets (never their values)
# and the resolved endpoints so misconfiguration is visible in the logs.
print("===== ENV CHECK =====")
for _label, _value in (
    ("HF_TOKEN exists:", HF_TOKEN is not None),
    ("AUTH_COOKIE_TOKEN exists:", AUTH_COOKIE_TOKEN is not None),
    ("BASE_API:", BASE_API),
    ("MODEL_NAME:", MODEL_NAME),
    ("HF_API_URL:", HF_API_URL),
):
    print(_label, _value)
print("======================")
# ======================================================
# BACKEND HELPERS
# ======================================================
def backend_headers() -> Dict[str, str]:
    """Build the headers sent with every request to the catering backend.

    The backend authenticates via an ``auth-token`` cookie (read from the
    environment at startup) rather than an Authorization header.
    """
    headers = {"Content-Type": "application/json"}
    headers["Cookie"] = f"auth-token={AUTH_COOKIE_TOKEN}"
    return headers
def http_get(url: str):
    """GET *url* with backend auth headers and return the decoded JSON body.

    Raises:
        RuntimeError: on any HTTP 4xx/5xx response, with status and body text.
    """
    response = requests.get(url, headers=backend_headers(), timeout=30)
    if response.status_code < 400:
        return response.json()
    raise RuntimeError(f"GET FAILED {response.status_code}: {response.text}")
# ======================================================
# LLM CALL
# ======================================================
# System prompt prepended to every LLM request (see call_llm below).
# NOTE: this text is sent verbatim to the model — do not reformat it.
SYSTEM_PROMPT = """
You generate catering recommendation plans.
Rules:
- Stay under or close to budget.
- Use ONLY menu items provided.
- Output VALID JSON only.
"""
def call_llm(quote, menu_items, categories):
    """Ask the hosted Llama-2 model for a catering plan and parse it as JSON.

    Args:
        quote: The catering quote request fetched from the backend.
        menu_items: Menu items the model may use (see SYSTEM_PROMPT rules).
        categories: Menu categories passed through for context.

    Returns:
        The plan decoded from the model's JSON output (a dict/list per the model).

    Raises:
        RuntimeError: if the model is unavailable, the HTTP call fails, or
            the inference API reports an error.
        json.JSONDecodeError: if the model's output is not valid JSON.
    """
    payload = {
        "quote_request": quote,
        "menu_items": menu_items,
        "categories": categories
    }
    user_prompt = json.dumps(payload, indent=2)
    body = {
        "inputs": f"{SYSTEM_PROMPT}\n\nUSER:\n{user_prompt}",
        "parameters": {
            "temperature": 0.2,
            "max_new_tokens": 700,
            # Without this the HF text-generation API echoes the prompt back
            # in generated_text, which makes json.loads below fail.
            "return_full_text": False
        }
    }
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json"
    }
    response = requests.post(HF_API_URL, headers=headers, json=body, timeout=90)
    if response.status_code == 404:
        raise RuntimeError("Model not available on router")
    # Fail loudly on any other HTTP error instead of trying to parse the body.
    if response.status_code >= 400:
        raise RuntimeError(f"LLM call failed {response.status_code}: {response.text}")
    data = response.json()
    if isinstance(data, list):
        txt = data[0].get("generated_text", "").strip()
    else:
        # The inference API reports problems (e.g. model loading) as
        # {"error": ...} with a 200 status — surface that explicitly rather
        # than letting json.loads("") raise a confusing JSONDecodeError.
        if "error" in data:
            raise RuntimeError(f"LLM error: {data['error']}")
        txt = data.get("generated_text", "").strip()
    # Strip markdown code fences the model sometimes wraps around its JSON.
    if txt.startswith("```"):
        txt = txt.replace("```json", "").replace("```", "").strip()
    return json.loads(txt)
# ======================================================
# FASTAPI BACKEND
# ======================================================
app = FastAPI(title="Catering AI Generator")
# NOTE(review): wildcard CORS allows any origin to call this API — acceptable
# for a demo Space, but tighten allow_origins before exposing real data.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"]
)
class GenerateRequest(BaseModel):
    """Request body for POST /generate-plan."""
    # Backend ID of the catering quote to build a plan for.
    quote_id: str
class GenerateResponse(BaseModel):
    """Response body for POST /generate-plan."""
    # Echo of the requested quote ID.
    quote_id: str
    # LLM-generated catering plan (shape defined by the model's JSON output).
    plan: Dict[str, Any]
    # Always False here: the plan is returned to the caller, never saved.
    persisted: bool
@app.post("/generate-plan", response_model=GenerateResponse)
def generate_plan_api(body: GenerateRequest):
    """Fetch the quote and menu data from the backend, run the LLM, return the plan.

    The plan is not persisted anywhere — ``persisted`` is always False.
    """
    sources = {
        "quote": f"{BASE_API}/api/admin/catering-requests/{body.quote_id}",
        "menu_items": f"{BASE_API}/api/menu/items",
        "categories": f"{BASE_API}/api/menu/categories",
    }
    # dict preserves insertion order, so fetches happen quote -> items -> categories.
    fetched = {name: http_get(url) for name, url in sources.items()}
    plan = call_llm(fetched["quote"], fetched["menu_items"], fetched["categories"])
    return GenerateResponse(quote_id=body.quote_id, plan=plan, persisted=False)
# ======================================================
# GRADIO UI (for HuggingFace)
# ======================================================
def gradio_generate(quote_id):
    """Gradio handler: build a plan for *quote_id*.

    Returns the plan on success, or ``{"error": <message>}`` on any failure
    so the UI always has JSON to display.
    """
    try:
        endpoints = (
            f"{BASE_API}/api/admin/catering-requests/{quote_id}",
            f"{BASE_API}/api/menu/items",
            f"{BASE_API}/api/menu/categories",
        )
        quote, menu_items, categories = (http_get(url) for url in endpoints)
        return call_llm(quote, menu_items, categories)
    except Exception as exc:
        return {"error": str(exc)}
# Gradio front end wrapping gradio_generate: one text input, JSON output.
quote_id_box = gr.Textbox(label="Enter Quote ID")
ui = gr.Interface(
    fn=gradio_generate,
    inputs=quote_id_box,
    outputs="json",
    title="Catering AI Generator (Llama-2-7B)",
    description="Generates a catering plan using AI.",
)
# HuggingFace Spaces runs Gradio by default
def main() -> None:
    """Launch the Gradio UI on all interfaces at the Spaces default port."""
    launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7860}
    ui.launch(**launch_kwargs)


if __name__ == "__main__":
    main()