# kitchenelite-api / main.py
# Last commit d1c9cd0 by abhinavsunil: "Update main.py" (verified)
import os
import re
import pandas as pd
import faiss
import torch
import numpy as np
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
from huggingface_hub import snapshot_download
# Hugging Face repo that holds the embedding model, FAISS index and metadata.
REPO_ID = "abhinavsunil/kitchenelite-recipe-model"
# Local directory the repo snapshot is downloaded into at startup.
MODEL_CACHE = "/tmp/model_cache"
# Number of nearest recipes returned by /search.
TOP_K = 5

app = FastAPI(title="KitchenElite Recipe Search API")

# Wide-open CORS: any origin, any HTTP method, any request header.
# NOTE(review): combining allow_credentials=True with a "*" origin makes
# Starlette reflect the caller's Origin rather than "*", so any site could make
# credentialed requests. No visible route reads cookies/auth, so
# allow_credentials=False is probably enough. Confirm no client depends on it.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Search assets. All three stay None until the startup hook has loaded them.
model = None  # SentenceTransformer used to embed incoming queries
index = None  # FAISS index searched for nearest recipe vectors
df = None  # recipe metadata DataFrame; FAISS ids are used as its row positions
# ==============================
# STARTUP EVENT
# ==============================
@app.on_event("startup")
def load_assets():
    """Download the model repo snapshot and load every search asset into globals.

    Populates the module-level globals:
      * ``df``    -- recipe metadata read from ``metadata.parquet``;
      * ``index`` -- FAISS index read from ``recipes.index``;
      * ``model`` -- SentenceTransformer loaded from the snapshot dir, on CPU.
    """
    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    # favour of a ``lifespan`` handler. Migrating also requires changing the
    # ``FastAPI(...)`` constructor call, so it is left as-is here.
    global model, index, df

    print("🚀 Downloading model repo snapshot...")
    local_dir = snapshot_download(
        repo_id=REPO_ID,
        local_dir=MODEL_CACHE,
        # Ask for real files in MODEL_CACHE rather than symlinks into the HF
        # cache. NOTE(review): newer huggingface_hub versions deprecate and
        # ignore this argument. Confirm against the pinned version.
        local_dir_use_symlinks=False,
    )

    print("📦 Loading metadata...")
    df = pd.read_parquet(os.path.join(local_dir, "metadata.parquet"))

    print("📦 Loading FAISS index...")
    index = faiss.read_index(os.path.join(local_dir, "recipes.index"))

    # search() treats FAISS result ids as row positions into df. A size
    # mismatch means misaligned or failing results, so surface it at startup
    # instead of on the first unlucky request.
    if index.ntotal != len(df):
        print(
            f"⚠️ FAISS index holds {index.ntotal} vectors but metadata has "
            f"{len(df)} rows; search results may be misaligned."
        )

    print("📦 Loading SentenceTransformer model...")
    model = SentenceTransformer(local_dir, device="cpu")
    print("✅ All assets loaded successfully!")
# ==============================
# UTILITY FUNCTION
# ==============================
# Matches an R-deparsed character vector: c("step 1", "step 2", ...).
# DOTALL lets a step contain embedded newlines. Without it, such values fail
# to match and are returned unparsed as one raw step.
_R_VECTOR_PATTERN = re.compile(r'c\("(.*)"\)', re.DOTALL)


def clean_instructions(instruction_input):
    """Normalise a recipe's stored instructions into a list of step strings.

    Handled forms:
      * R-style vector string ``'c("step 1", "step 2")'`` -> one item per step;
      * list or numpy array -> converted to a plain list, items unchanged;
      * ``None`` or float NaN (missing value) -> ``[]``;
      * anything else (including unparseable strings) -> ``[str(instruction_input)]``.
    """
    # Missing metadata would otherwise surface as the literal step "None"/"nan".
    if instruction_input is None:
        return []
    if isinstance(instruction_input, float) and np.isnan(instruction_input):
        return []

    if isinstance(instruction_input, str) and instruction_input.startswith('c("'):
        content = _R_VECTOR_PATTERN.search(instruction_input)
        if content:
            # Steps are separated by '", "'. Strip leftover whitespace/quotes at the ends.
            return [
                step.strip().strip('"')
                for step in content.group(1).split('", "')
            ]

    if isinstance(instruction_input, (list, np.ndarray)):
        return list(instruction_input)
    return [str(instruction_input)]
# ==============================
# ROUTES
# ==============================
@app.get("/")
def home():
    """Return a static status payload.

    This does not check whether the search assets have finished loading.
    """
    return {"status": "KitchenElite API Running 🚀"}
def _finite_float_or_none(value):
    """Convert *value* to float, or return None when it is missing, NaN or infinite.

    NaN/inf are not valid JSON, and the JSON response renderer refuses them,
    so a single bad metadata cell would otherwise fail the whole request.
    """
    if value is None or value is pd.NA:
        return None
    number = float(value)
    return number if np.isfinite(number) else None


@app.get("/search")
def search(query: str):
    """Return the TOP_K recipes whose embeddings are closest to *query*.

    The query is embedded with the global ``model`` and L2-normalised, then
    searched in the global FAISS ``index``. Resulting ids are used as row
    positions into the global ``df``.

    Returns:
        ``{"query": query, "results": [...]}``. Each result has ``name``,
        ``ingredients``, ``calories``, ``protein`` (``None`` when missing or
        non-finite) and ``instructions`` (a list of steps). There may be fewer
        than TOP_K results if the index holds fewer vectors.
    """
    # normalize_L2 works in place and needs a C-contiguous float32 array, so
    # convert before normalising rather than after.
    query_vector = np.ascontiguousarray(model.encode([query]), dtype="float32")
    faiss.normalize_L2(query_vector)
    _distances, indices = index.search(query_vector, TOP_K)

    # FAISS pads missing neighbours with id -1. df.iloc[-1] would silently
    # return the last recipe, so drop those ids.
    hit_positions = [int(i) for i in indices[0] if i >= 0]
    results = df.iloc[hit_positions]

    output = [
        {
            "name": str(row["name"]),
            "ingredients": (
                list(row["ingredients"])
                if isinstance(row["ingredients"], np.ndarray)
                else row["ingredients"]
            ),
            "calories": _finite_float_or_none(row["calories"]),
            "protein": _finite_float_or_none(row["protein"]),
            "instructions": clean_instructions(row["RecipeInstructions"]),
        }
        for _, row in results.iterrows()
    ]
    return {
        "query": query,
        "results": output,
    }