"""FastAPI service for MAE-based food image retrieval and reconstruction.

At import time this module loads an MAE encoder checkpoint, a Hugging Face
dataset and a FAISS index (downloading the index on first run), then exposes:

* ``GET /``         -- the static front-end page (``static/index2.html``).
* ``GET /health``   -- a liveness probe.
* ``POST /upload``  -- embeds an uploaded image and returns the 6 nearest
  indexed images as base64 JPEG data URIs with labels and scores.
* ``POST /predict`` -- runs MAE inference on an uploaded image.
"""

from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi import Request
import shutil
import tempfile
import os
import glob
import json
import io
import base64
import torch
import faiss
import requests
import uvicorn
import numpy as np
import pyarrow.parquet as pq
from PIL import Image
from fastapi.middleware.cors import CORSMiddleware
from PIL import ImageDraw, ImageFont
from inference_vit import MAEEncoder
from inference_vit import get_embedding
from inference_vit import faiss_retrieve
from inference_vit import load_image_by_id, load_image_by_id2
from datasets import load_dataset
from inference_vit_2 import run_mae_inference

app = FastAPI()

# NOTE(review): with allow_origins=["*"] and allow_credentials=True,
# Starlette's CORSMiddleware echoes the caller's Origin for requests that
# carry cookies, so any site can make credentialed requests. Restrict
# allow_origins to the real front-end origin(s) if credentials matter.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Static assets live under /static; the page itself is served by "/" below.
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
def serve_index():
    """Serve the front-end HTML page."""
    return FileResponse("static/index2.html")


@app.get("/health")
def healthcheck():
    """Liveness probe; always returns ``{"status": "ok"}``."""
    return {"status": "ok"}


device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Load model & FAISS index once, at import time -------------------------
print("Loading model and FAISS index...")
ckpt_path = "mae_epoch_new_69.pth"
model = MAEEncoder().to(device)
# NOTE(review): torch>=2.6 defaults torch.load(weights_only=True); if this
# checkpoint stores non-tensor objects next to "model", loading may need
# weights_only=False (trusted files only) -- confirm.
ckpt = torch.load(ckpt_path, map_location=device)
state_dict = ckpt["model"]
# strict=False tolerates key mismatches between checkpoint and encoder; report
# them rather than discarding them silently, so a wrong checkpoint is visible.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"⚠️ Checkpoint key mismatch: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected"
    )
model.eval()

# Food101 from Hugging Face.
# NOTE(review): not referenced elsewhere in this file; kept in case other
# modules import it from here or rely on it being loaded -- confirm.
dataset = load_dataset("Multimodal-Fatima/Food101_train")

INDEX_URL = "https://huggingface.co/musk12/index-embeddings-file-vit/resolve/main/mae_food.index"
INDEX_PATH = "mae_food.index"
# Id array passed to faiss_retrieve together with the index (presumably
# entry i names the image of FAISS vector i -- confirm).
image_names = np.load("image_ids.npy")


def _download_file(url, dest, timeout=60):
    """Stream ``url`` into ``dest``, replacing ``dest`` only on success.

    Raises ``requests.HTTPError`` on a non-2xx response instead of saving the
    error body. Data is written to ``dest + ".part"`` first and moved into
    place with ``os.replace``, so a failed or interrupted download never leaves
    a truncated ``dest`` that the existence check below would trust next run.
    ``timeout`` (seconds) bounds connecting and each wait for data.
    """
    part_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as resp:
            resp.raise_for_status()
            with open(part_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(part_path, dest)
    finally:
        # No-op after a successful os.replace; removes leftovers on failure.
        if os.path.exists(part_path):
            os.remove(part_path)


# Download the index on first run only.
if not os.path.exists(INDEX_PATH):
    print("Downloading FAISS index from Hugging Face...")
    _download_file(INDEX_URL, INDEX_PATH)
    print("✅ Download complete:", INDEX_PATH)

index = faiss.read_index(INDEX_PATH)


def _to_jpeg_data_uri(img):
    """Encode a PIL image as a ``data:image/jpeg;base64,...`` URI.

    JPEG cannot store alpha or palette modes (saving e.g. an RGBA image as
    JPEG raises), so any mode other than RGB/L is converted to RGB first.
    """
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    buffered = io.BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"


@app.post("/upload")
async def upload_image(file: UploadFile = File(...)):
    """Return the indexed images nearest to the uploaded image.

    The upload is written to a private temporary file, embedded with the MAE
    encoder via ``get_embedding`` and queried against the FAISS index
    (``top_k=6``). Hits for which ``load_image_by_id2`` returns no image are
    skipped.

    Returns:
        ``{"results": [{"image": <JPEG data URI>, "label": ..., "score": float}, ...]}``
        with each score rounded to 4 decimals.
    """
    # A per-request temp file (instead of a shared "temp.jpg") keeps concurrent
    # uploads from overwriting each other; `finally` removes it even if
    # embedding fails.
    fd, temp_path = tempfile.mkstemp(suffix=".jpg")
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        # NOTE(review): get_embedding runs synchronously inside this async
        # handler and blocks the event loop while it runs; consider
        # fastapi.concurrency.run_in_threadpool if other routes must stay
        # responsive during inference.
        query_emb = get_embedding(model, temp_path, device)
    finally:
        os.remove(temp_path)

    results = faiss_retrieve(query_emb, index, image_names, top_k=6)

    results_list = []
    for image_id, score in results:
        img, label = load_image_by_id2(image_id)
        if img is None:
            continue
        results_list.append({
            "image": _to_jpeg_data_uri(img),
            "label": label,
            "score": round(float(score), 4),
        })

    return {"results": results_list}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run MAE inference on the uploaded image.

    Returns the ``original``, ``reconstructed`` and ``mae_output`` entries of
    ``run_mae_inference``'s result under the keys ``original_image``,
    ``reconstructed_image`` and ``mae_output_image``.
    """
    contents = await file.read()
    result = run_mae_inference(io.BytesIO(contents))
    return {
        "original_image": result["original"],
        "reconstructed_image": result["reconstructed"],
        "mae_output_image": result["mae_output"],
    }


# To run locally without the uvicorn CLI:
# if __name__ == "__main__":
#     uvicorn.run("main_api:app", host="0.0.0.0", port=7860, reload=True)