# Hugging Face Space application (Space status when this copy was captured: "Sleeping").
| from fastapi import FastAPI, Request, UploadFile, Form | |
| from fastapi.responses import RedirectResponse, HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| import uuid | |
| import shutil | |
| import os | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| from pipeline.process_session import process_session_image | |
| from database.db import init_db | |
| from database.crud import get_session, delete_session | |
| from models.sam import SamWrapper | |
| from models.dino import DinoWrapper | |
| from huggingface_hub import hf_hub_download | |
# Held by process_image for the whole save-and-process step, so uploads are
# handled one at a time.
process_lock = asyncio.Lock()
# A single worker thread: blocking pipeline calls run off the event loop and
# never overlap each other.
executor = ThreadPoolExecutor(max_workers=1)


async def run_in_thread(func, *args, **kwargs):
    """Await ``func(*args, **kwargs)`` executed on the module's ``executor``.

    ``loop.run_in_executor`` forwards only positional arguments, so the call
    is packaged as a zero-argument closure that also carries ``kwargs``.

    Returns:
        Whatever ``func`` returns. An exception raised by ``func`` is
        re-raised in the awaiting coroutine.
    """
    def _invoke():
        return func(*args, **kwargs)

    running_loop = asyncio.get_running_loop()
    outcome = await running_loop.run_in_executor(executor, _invoke)
    return outcome
app = FastAPI()

# --- Templates and static file mounts.
# All three directories are resolved relative to the current working
# directory, not BASE_DIR.
# NOTE(review): Starlette's StaticFiles checks by default that its directory
# exists, so "outputs" and "static" presumably must exist at startup — confirm.
templates = Jinja2Templates(directory="templates")
app.mount(
    "/outputs",
    StaticFiles(directory="outputs"),
    name="outputs",
)  # still served from the project root
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
# Diagnostic output only: report whether the candidate locations are writable.
# Note that "./weights" is the directory actually used only when BASE_DIR
# falls back to ".".
print("WRITE to ./weights:", os.access("weights", os.W_OK))
print("WRITE to /tmp:", os.access("/tmp", os.W_OK))

# --- Choose a writable base directory: prefer /tmp, else the working directory.
if os.access("/tmp", os.W_OK):
    BASE_DIR = "/tmp"
else:
    BASE_DIR = "."
WEIGHTS_DIR = os.path.join(BASE_DIR, "weights")
os.makedirs(WEIGHTS_DIR, exist_ok=True)  # no-op when it already exists

# --- Initialize database (init_db comes from database.db; its behavior is not
# visible in this file).
init_db()
# === Download and load model SAM checkpoint ===
FILENAME = "sam_vit_b_01ec64.pth"
REPO_ID = "stkrk/sam-vit-b-checkpoint"
MODEL_PATH = os.path.join(WEIGHTS_DIR, FILENAME)
if not os.path.exists(MODEL_PATH):
    print(f"Model not found locally. Downloading from {REPO_ID}...")
    # Download straight into WEIGHTS_DIR via ``local_dir`` (the same approach
    # used for the Grounding DINO files) rather than into the hub cache
    # followed by a copy, which kept two full checkpoints on disk.
    # ``local_dir_use_symlinks`` only has an effect together with
    # ``local_dir``; it is kept so older hub versions write a real file.
    downloaded_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        cache_dir=WEIGHTS_DIR,
        local_dir=WEIGHTS_DIR,
        local_dir_use_symlinks=False,
    )
    # With ``local_dir`` the file should already be at MODEL_PATH; copy only
    # as a fallback if the hub reports a different location.
    if os.path.abspath(downloaded_path) != os.path.abspath(MODEL_PATH):
        shutil.copy(downloaded_path, MODEL_PATH)
    print(f"Model downloaded to {MODEL_PATH}.")
else:
    print(f"Model already exists at {MODEL_PATH}.")
# === Download and prepare Grounding DINO checkpoint ===
DINO_REPO_ID = "stkrk/dino_base"
DINO_DIR = os.path.join(WEIGHTS_DIR, "grounding_dino_base")
os.makedirs(DINO_DIR, exist_ok=True)

# Every file fetched individually from DINO_REPO_ID into DINO_DIR; a file
# that is already present is not downloaded again.
DINO_FILES = [
    "config.json",
    "model.safetensors",
    "preprocessor_config.json",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "tokenizer.json",
    "vocab.txt",
]

for filename in DINO_FILES:
    target_path = os.path.join(DINO_DIR, filename)
    if os.path.exists(target_path):
        print(f"{filename} already exists in {DINO_DIR}.")
        continue

    print(f"Downloading {filename} from {DINO_REPO_ID}...")
    hf_hub_download(
        repo_id=DINO_REPO_ID,
        filename=filename,
        cache_dir=DINO_DIR,
        local_dir=DINO_DIR,
        local_dir_use_symlinks=False,
    )
# --- Initialize models once at import time; process_image passes these same
# instances to the pipeline for every request.
sam = SamWrapper(model_type="vit_b", checkpoint_path=MODEL_PATH)
dino = DinoWrapper(model_dir=DINO_DIR)
def index(request: Request):
    """Render the landing page template ``index.html``."""
    # NOTE(review): no route decorator is visible in this copy; presumably
    # registered for GET "/" — confirm.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def show_results(request: Request, session_id: str):
    """Render ``results.html`` with the result paths stored for ``session_id``.

    When no session record is found, ``done.html`` is rendered instead with
    the message "Session not found.".
    """
    # NOTE(review): no route decorator is visible in this copy; process_image
    # redirects to "/results/{session_id}", presumably served here — confirm.
    session = get_session(session_id)
    if session:
        context = {
            "request": request,
            "session_id": session_id,
            "result_paths": session["result_paths"],
        }
        return templates.TemplateResponse("results.html", context)

    missing_context = {"request": request, "message": "Session not found."}
    return templates.TemplateResponse("done.html", missing_context)
async def process_image(request: Request, image: UploadFile = Form(...), prompt: str = Form(...)):
    """Save an uploaded image, run the segmentation pipeline on it, and redirect.

    The whole step runs while holding ``process_lock``, so uploads are handled
    one at a time. The blocking pipeline call is offloaded with
    ``run_in_thread``.

    Args:
        request: Incoming request (not used by this handler).
        image: Uploaded image file. Its client-supplied name is reduced to a
            bare file name before it is used in a filesystem path.
        prompt: Text prompt forwarded to ``process_session_image``.

    Returns:
        A 303 redirect to ``/results/{session_id}`` for the new session.
    """
    # NOTE(review): no route decorator is visible in this copy — confirm how
    # this handler is registered.
    async with process_lock:
        # 1. Save uploaded image under a fresh, unique session id.
        session_id = uuid.uuid4().hex
        save_dir = "uploads"  # NOTE(review): CWD-relative, not under BASE_DIR
        os.makedirs(save_dir, exist_ok=True)
        # ``image.filename`` is client-controlled and may be missing or carry
        # directory parts ("../x", "C:\\dir\\x.png"). Keep only the final
        # component so the file always lands directly inside ``save_dir``
        # and a separator cannot make ``open`` fail on a missing subdirectory.
        client_name = (image.filename or "").replace("\\", "/")
        safe_name = os.path.basename(client_name) or "upload"
        image_path = os.path.join(save_dir, f"{session_id}_{safe_name}")
        with open(image_path, "wb") as buffer:
            shutil.copyfileobj(image.file, buffer)

        # 2. Run main pipeline (blocking work, executed on the worker thread).
        await run_in_thread(
            process_session_image,
            session_id=session_id,
            image_path=image_path,
            prompt_text=prompt,
            sam_wrapper=sam,
            dino_wrapper=dino
        )

        # 3. Redirect to results page (303 so the browser follows with GET).
        return RedirectResponse(f"/results/{session_id}", status_code=303)
async def finalize_selection(
    request: Request,
    session_id: str = Form(...),
    selected: list[str] = Form(default=[])
):
    """Keep the user's chosen result images, delete the rest, and close the session.

    Args:
        request: Incoming request, forwarded to the template context.
        session_id: Session whose results are being finalized.
        selected: Result paths the user chose to keep. These values come from
            the client, so they only decide which of the session's own
            ``result_paths`` survive; they never name files to delete.

    Returns:
        ``ready.html`` rendered with a status message. The message is
        "Session not found." when no session record exists.
    """
    # NOTE(review): no route decorator is visible in this copy — confirm how
    # this handler is registered.
    session = get_session(session_id)
    if not session:
        return templates.TemplateResponse("ready.html", {"request": request, "message": "Session not found."})

    result_paths = session["result_paths"]
    keep = set(selected)  # O(1) membership instead of scanning the list per path
    # Report only selections that are actual outputs of this session; the raw
    # len(selected) also counted unknown or duplicated form values.
    saved = keep.intersection(result_paths)

    # Remove all the rest of PNGs (every result path not selected).
    for path in result_paths:
        if path in keep:
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # already gone: nothing left to delete

    # Remove the closed session record.
    delete_session(session_id)
    return templates.TemplateResponse("ready.html", {
        "request": request,
        "message": f"Saved {len(saved)} file(s). Session {session_id} closed."
    })
if __name__ == "__main__":
    import uvicorn

    # Pass the already-built ``app`` object instead of the "run_fastapi:app"
    # import string. When this file runs as a script it lives in sys.modules
    # as "__main__", so an import string makes uvicorn import it again under
    # a second name. That repeats every module-level side effect (download
    # checks, init_db(), and loading both SAM and DINO models) in the same
    # process, and fails outright if the file is not named run_fastapi.py.
    # An import string is only needed for reload/workers, which are not used.
    uvicorn.run(app, host="0.0.0.0", port=7860)