Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.responses import JSONResponse, FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi import Request | |
| import shutil | |
| import os, glob, json, io, base64 | |
| import torch | |
| import faiss, requests, uvicorn | |
| import numpy as np | |
| import pyarrow.parquet as pq | |
| from PIL import Image | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import ImageDraw, ImageFont | |
| from fastapi.staticfiles import StaticFiles | |
| from inference_vit import MAEEncoder | |
| from inference_vit import get_embedding | |
| from inference_vit import faiss_retrieve | |
| from inference_vit import load_image_by_id, load_image_by_id2 | |
| from datasets import load_dataset | |
| from inference_vit_2 import run_mae_inference | |
app = FastAPI()

# Cross-origin settings: any origin, method and header may call the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette reflect the caller's Origin for credentialed requests instead of
# sending "*"; confirm this permissiveness is intended.
cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **cors_options)

# Serve everything under the local "static" directory at the /static prefix.
static_files = StaticFiles(directory="static")
app.mount("/static", static_files, name="static")
| # Serve your index.html | |
| def serve_index(): | |
| return FileResponse("static/index2.html") | |
# NOTE(review): no route decorator is visible for this function; confirm it is
# registered on `app` (the path cannot be determined from this file).
def healthcheck():
    """Report liveness with a constant ``{"status": "ok"}`` payload."""
    status = "ok"
    return {"status": status}
# Use the GPU when one is available; the model and the checkpoint tensors are
# both placed on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and FAISS index once, at import time, so every request
# reuses them.
print("Loading model and FAISS index...")
ckpt_path = "mae_epoch_new_69.pth"
model = MAEEncoder().to(device)
ckpt = torch.load(ckpt_path, map_location=device)
# The checkpoint stores the weights under its "model" key.
state_dict = ckpt["model"]
# strict=False tolerates key mismatches between checkpoint and encoder.
# Silently discarding the mismatch report would also hide a checkpoint whose
# keys match nothing, which leaves the encoder at its initial weights.
# Report the mismatches instead of dropping them.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"⚠️ {len(load_result.missing_keys)} model keys not found in "
        f"{ckpt_path}, e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"⚠️ {len(load_result.unexpected_keys)} checkpoint keys unused by "
        f"the model, e.g. {load_result.unexpected_keys[:5]}"
    )
model.eval()
# Food-101 training split from the Hugging Face Hub.
# NOTE(review): `dataset` is not referenced anywhere in this file's visible
# code; confirm something still needs it before dropping this (large) load.
dataset = load_dataset("Multimodal-Fatima/Food101_train")

INDEX_URL = "https://huggingface.co/musk12/index-embeddings-file-vit/resolve/main/mae_food.index"
# Image ids handed to faiss_retrieve alongside the index (presumably one id
# per index row; confirm against how image_ids.npy was built).
image_names = np.load("image_ids.npy")
INDEX_PATH = "mae_food.index"


def _download_file(url, dest_path, timeout=60, chunk_size=1 << 20):
    """Download ``url`` to ``dest_path``, creating it only on full success.

    The response body is streamed into ``dest_path + ".part"`` and renamed into
    place once complete. A failed or interrupted download therefore never
    leaves an error page or truncated file at ``dest_path``. Such a file would
    otherwise be accepted forever by the existence check below.

    Args:
        url: Source URL.
        dest_path: Final file path.
        timeout: Seconds to wait for the connection / between received bytes.
        chunk_size: Bytes per streamed chunk.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.RequestException: Connection failure or timeout.
    """
    tmp_path = dest_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic rename: dest_path is either absent or complete.
        os.replace(tmp_path, dest_path)
    finally:
        # The partial file only still exists if something above failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# Download the index on first start-up only.
if not os.path.exists(INDEX_PATH):
    print("Downloading FAISS index from Hugging Face...")
    _download_file(INDEX_URL, INDEX_PATH)
    print("✅ Download complete:", INDEX_PATH)

# Load the FAISS index.
index = faiss.read_index(INDEX_PATH)
# Pillow modes the JPEG encoder accepts directly; anything else (e.g. RGBA, P,
# LA) makes Image.save(..., format="JPEG") raise.
_JPEG_SAVEABLE_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")


def _to_jpeg_data_uri(img):
    """Encode an image as a ``data:image/jpeg;base64,...`` URI string.

    Assumes ``img`` is a PIL Image (it must support ``.mode``, ``.convert``
    and ``.save(..., format="JPEG")``).
    """
    if img.mode not in _JPEG_SAVEABLE_MODES:
        img = img.convert("RGB")
    buffered = io.BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"


# NOTE(review): no route decorator is visible for this function; confirm it is
# registered on `app`.
async def upload_image(file: UploadFile = File(...)):
    """Return the indexed images most similar to an uploaded query image.

    The upload is written to a private temporary file and embedded with
    ``get_embedding``. The six nearest neighbours are then looked up with
    ``faiss_retrieve``. Each hit is loaded with ``load_image_by_id2`` and
    returned inline as a base64 JPEG data URI.

    Args:
        file: The uploaded query image.

    Returns:
        ``{"results": [{"image": <data URI>, "label": ..., "score": <float
        rounded to 4 places>}, ...]}``. Hits whose image cannot be loaded
        (``load_image_by_id2`` returns ``None``) are skipped.
    """
    import tempfile

    # NOTE(review): this coroutine never awaits, so embedding and retrieval run
    # on the event-loop thread and block other requests while they execute.

    # A unique temp file per request (instead of a shared "temp.jpg" in the
    # working directory, which worker processes could overwrite for each
    # other). It is removed even if embedding fails.
    fd, temp_path = tempfile.mkstemp(suffix=".jpg")
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        query_emb = get_embedding(model, temp_path, device)
    finally:
        os.remove(temp_path)

    results = faiss_retrieve(query_emb, index, image_names, top_k=6)

    results_list = []
    for image_id, score in results:
        img, label = load_image_by_id2(image_id)
        if img is None:
            continue
        results_list.append({
            "image": _to_jpeg_data_uri(img),
            "label": label,
            "score": round(float(score), 4),
        })
    return {"results": results_list}
# NOTE(review): no route decorator is visible for this function; confirm it is
# registered on `app`.
async def predict(file: UploadFile = File(...)):
    """Run MAE inference on an uploaded image.

    The raw upload bytes are wrapped in an in-memory stream and passed to
    ``run_mae_inference``. Its ``original``, ``reconstructed`` and
    ``mae_output`` entries are returned under the ``*_image`` response keys.
    """
    raw_bytes = await file.read()
    inference = run_mae_inference(io.BytesIO(raw_bytes))
    # (response key, key in run_mae_inference's result), in response order.
    key_pairs = (
        ("original_image", "original"),
        ("reconstructed_image", "reconstructed"),
        ("mae_output_image", "mae_output"),
    )
    return {public: inference[internal] for public, internal in key_pairs}
| # if __name__ == "__main__": | |
| # uvicorn.run("main_api:app", host="0.0.0.0", port=7860, reload=True) |