"""Catering AI plan generator.

Fetches a catering quote plus menu data from the backend API, asks a
Llama-2 chat model (via the Hugging Face inference router) to produce a
JSON catering plan, and exposes the flow through both a FastAPI endpoint
and a Gradio UI (the Gradio app is what Hugging Face Spaces runs).
"""

import json
import os
from typing import Any, Dict

import dotenv
import gradio as gr
import requests
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

dotenv.load_dotenv()

# ======================================================
# ENVIRONMENT
# ======================================================
HF_TOKEN = os.getenv("HF_TOKEN")
AUTH_COOKIE_TOKEN = os.getenv("AUTH_COOKIE_TOKEN")
BASE_API = os.getenv("BASE_API", "https://catering-lac.vercel.app")
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
HF_API_URL = f"https://router.huggingface.co/hf-inference/models/{MODEL_NAME}"

# Startup sanity log so a misconfigured Space is obvious from the logs.
print("===== ENV CHECK =====")
print("HF_TOKEN exists:", HF_TOKEN is not None)
print("AUTH_COOKIE_TOKEN exists:", AUTH_COOKIE_TOKEN is not None)
print("BASE_API:", BASE_API)
print("MODEL_NAME:", MODEL_NAME)
print("HF_API_URL:", HF_API_URL)
print("======================")


# ======================================================
# BACKEND HELPERS
# ======================================================
def backend_headers() -> Dict[str, str]:
    """Headers for the catering backend (cookie-based auth)."""
    return {
        "Content-Type": "application/json",
        "Cookie": f"auth-token={AUTH_COOKIE_TOKEN}",
    }


def http_get(url: str):
    """GET *url* with backend auth and return the decoded JSON body.

    Raises:
        RuntimeError: on any HTTP status >= 400 (message includes body).
    """
    res = requests.get(url, headers=backend_headers(), timeout=30)
    if res.status_code >= 400:
        raise RuntimeError(f"GET FAILED {res.status_code}: {res.text}")
    return res.json()


# ======================================================
# LLM CALL
# ======================================================
SYSTEM_PROMPT = """
You generate catering recommendation plans. Rules:
- Stay under or close to budget.
- Use ONLY menu items provided.
- Output VALID JSON only.
"""


def _extract_json(text: str) -> Dict[str, Any]:
    """Parse the model output into a dict, tolerating markdown fences
    and any chatter before/after the outermost JSON object.

    Raises:
        RuntimeError: if no JSON object can be located in *text*.
        json.JSONDecodeError: if the located span is not valid JSON.
    """
    cleaned = text.replace("```json", "").replace("```", "").strip()
    # Slice to the outermost {...} span; chat models often add prose
    # around the JSON even when told not to.
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start == -1 or end == -1 or end < start:
        raise RuntimeError(f"LLM did not return JSON: {text[:200]!r}")
    return json.loads(cleaned[start : end + 1])


def call_llm(quote, menu_items, categories) -> Dict[str, Any]:
    """Ask the model for a catering plan and return it as a dict.

    Args:
        quote: the catering quote/request payload from the backend.
        menu_items: available menu items (only these may be used).
        categories: menu categories for grouping.

    Raises:
        RuntimeError: if the model is unavailable, the inference call
            fails, or the response contains no parsable JSON.
    """
    payload = {
        "quote_request": quote,
        "menu_items": menu_items,
        "categories": categories,
    }
    user_prompt = json.dumps(payload, indent=2)

    body = {
        "inputs": f"{SYSTEM_PROMPT}\n\nUSER:\n{user_prompt}",
        "parameters": {
            "temperature": 0.2,
            "max_new_tokens": 700,
            # BUG FIX: text-generation echoes the prompt inside
            # generated_text by default, which breaks json parsing of
            # the output. Ask for the completion only.
            "return_full_text": False,
        },
    }
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    }

    response = requests.post(HF_API_URL, headers=headers, json=body, timeout=90)
    if response.status_code == 404:
        raise RuntimeError("Model not available on router")
    # Previously only 404 was handled; 401/429/503 fell through to a
    # confusing parse failure. Fail loudly with the server's message.
    if response.status_code >= 400:
        raise RuntimeError(
            f"LLM CALL FAILED {response.status_code}: {response.text}"
        )

    data = response.json()
    if isinstance(data, list):
        txt = data[0].get("generated_text", "").strip()
    else:
        # Dict responses from the router are typically error payloads.
        if "error" in data:
            raise RuntimeError(f"LLM ERROR: {data['error']}")
        txt = data.get("generated_text", "").strip()

    return _extract_json(txt)


# ======================================================
# SHARED PIPELINE
# ======================================================
def _fetch_plan(quote_id: str) -> Dict[str, Any]:
    """Fetch quote + menu data from the backend and generate a plan.

    Single source of truth for the pipeline shared by the FastAPI
    endpoint and the Gradio UI (previously duplicated in both).
    """
    quote = http_get(f"{BASE_API}/api/admin/catering-requests/{quote_id}")
    menu_items = http_get(f"{BASE_API}/api/menu/items")
    categories = http_get(f"{BASE_API}/api/menu/categories")
    return call_llm(quote, menu_items, categories)


# ======================================================
# FASTAPI BACKEND
# ======================================================
app = FastAPI(title="Catering AI Generator")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class GenerateRequest(BaseModel):
    # ID of the catering quote/request to plan for.
    quote_id: str


class GenerateResponse(BaseModel):
    quote_id: str
    plan: Dict[str, Any]
    # Always False for now: the plan is returned but not saved anywhere.
    persisted: bool


@app.post("/generate-plan", response_model=GenerateResponse)
def generate_plan_api(body: GenerateRequest):
    """Generate a catering plan for the given quote ID."""
    plan = _fetch_plan(body.quote_id)
    return GenerateResponse(
        quote_id=body.quote_id,
        plan=plan,
        persisted=False,
    )


# ======================================================
# GRADIO UI (for HuggingFace)
# ======================================================
def gradio_generate(quote_id):
    """UI wrapper: return the plan dict, or an {'error': ...} dict.

    Broad catch is deliberate here — the UI should render the failure
    instead of crashing the Space.
    """
    try:
        return _fetch_plan(quote_id)
    except Exception as e:
        return {"error": str(e)}


ui = gr.Interface(
    fn=gradio_generate,
    inputs=gr.Textbox(label="Enter Quote ID"),
    outputs="json",
    title="Catering AI Generator (Llama-2-7B)",
    description="Generates a catering plan using AI.",
)


# HuggingFace Spaces runs Gradio by default
def main():
    ui.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()