# musk12's picture
# Update app.py
# 4a66b6e verified
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi import Request
import shutil
import os, glob, json, io, base64
import torch
import faiss, requests, uvicorn
import numpy as np
import pyarrow.parquet as pq
from PIL import Image
from fastapi.middleware.cors import CORSMiddleware
from PIL import ImageDraw, ImageFont
from fastapi.staticfiles import StaticFiles
from inference_vit import MAEEncoder
from inference_vit import get_embedding
from inference_vit import faiss_retrieve
from inference_vit import load_image_by_id, load_image_by_id2
from datasets import load_dataset
from inference_vit_2 import run_mae_inference
app = FastAPI()

# CORS: any origin, method and header may call this API.
# allow_credentials is False because a wildcard origin must not be combined
# with credentialed requests: the FastAPI CORS documentation requires
# explicit origins when allow_credentials=True, and nothing in this file
# reads cookies or Authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the static folder under /static; the root URL is served by the
# serve_index route instead of a catch-all static mount at "/".
app.mount("/static", StaticFiles(directory="static"), name="static")
# Front-end entry page, served at the site root.
@app.get("/")
def serve_index():
    """Return the static/index2.html page as a file response."""
    index_page = "static/index2.html"
    return FileResponse(index_page)
@app.get("/health")
def healthcheck():
    """Return a constant payload so callers can verify the app responds."""
    payload = dict(status="ok")
    return payload
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model & FAISS once, at import time, so request handlers reuse them.
print("Loading model and FAISS index...")
ckpt_path = "mae_epoch_new_69.pth"
model = MAEEncoder().to(device)
# NOTE: torch.load unpickles the checkpoint; only load files from a trusted
# source.
ckpt = torch.load(ckpt_path, map_location=device)
state_dict = ckpt["model"]
# strict=False: checkpoint keys that do not match MAEEncoder are skipped
# instead of raising.
model.load_state_dict(state_dict, strict=False)
model.eval()

# Food101 from Hugging Face
dataset = load_dataset("Multimodal-Fatima/Food101_train")

INDEX_URL = "https://huggingface.co/musk12/index-embeddings-file-vit/resolve/main/mae_food.index"
INDEX_PATH = "mae_food.index"

# Image ids handed to faiss_retrieve together with the index.
image_names = np.load("image_ids.npy")


def _download_index(url, dest_path, timeout=60):
    """Download *url* to *dest_path* without ever leaving a partial file there.

    The body is streamed to ``dest_path + ".part"`` and renamed into place
    only after the whole response has been written, so an interrupted or
    failed download cannot be mistaken for a valid index on the next start.

    Args:
        url: Address of the file to fetch.
        dest_path: Final location of the downloaded file.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-success HTTP status.
    """
    partial_path = dest_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Fail loudly instead of saving an HTTP error page as the index.
            response.raise_for_status()
            with open(partial_path, "wb") as out_file:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    out_file.write(chunk)
        os.replace(partial_path, dest_path)
    finally:
        # After a successful rename the partial file no longer exists.
        if os.path.exists(partial_path):
            os.remove(partial_path)


# Download if not exists
if not os.path.exists(INDEX_PATH):
    print("Downloading FAISS index from Hugging Face...")
    _download_index(INDEX_URL, INDEX_PATH)
    print("✅ Download complete:", INDEX_PATH)

# Load FAISS index
index = faiss.read_index(INDEX_PATH)
@app.post("/upload")
async def upload_image(file: UploadFile = File(...)):
    """Retrieve the indexed images closest to an uploaded image.

    The upload is written to a private temporary file, embedded with
    ``get_embedding``, and looked up with ``faiss_retrieve`` (top 6 hits).
    Each hit is loaded with ``load_image_by_id2``; hits for which no image
    is returned are skipped.

    Args:
        file: The uploaded image.

    Returns:
        ``{"results": [...]}`` where every entry holds ``image`` (a base64
        JPEG data URL), ``label``, and ``score`` (the value reported by
        ``faiss_retrieve``, rounded to 4 decimals).
    """
    # Imported here so the function does not depend on a file-level import.
    import tempfile

    # A unique file per request: a fixed "temp.jpg" can be overwritten by
    # another worker handling an upload at the same time.
    fd, temp_path = tempfile.mkstemp(suffix=".jpg")
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # 🔥 get embedding
        query_emb = get_embedding(model, temp_path, device)
        results = faiss_retrieve(query_emb, index, image_names, top_k=6)
    finally:
        # Always clean up, including when embedding or retrieval raises.
        os.remove(temp_path)

    results_list = []
    for image_id, score in results:
        img, label = load_image_by_id2(image_id)
        if img is None:
            continue

        # convert image to base64
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        results_list.append({
            "image": f"data:image/jpeg;base64,{img_str}",
            "label": label,
            "score": round(float(score), 4),
        })

    return {"results": results_list}
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run ``run_mae_inference`` on an uploaded image and relay its outputs.

    The upload is read fully into memory and passed to the inference helper
    as an in-memory byte stream. The helper's ``original``,
    ``reconstructed`` and ``mae_output`` entries are returned under the
    response keys ``original_image``, ``reconstructed_image`` and
    ``mae_output_image``.
    """
    image_stream = io.BytesIO(await file.read())
    outputs = run_mae_inference(image_stream)

    # Response key -> key in the inference helper's result.
    key_map = {
        "original_image": "original",
        "reconstructed_image": "reconstructed",
        "mae_output_image": "mae_output",
    }
    return {public: outputs[internal] for public, internal in key_map.items()}
# if __name__ == "__main__":
# uvicorn.run("main_api:app", host="0.0.0.0", port=7860, reload=True)