Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,48 +1,59 @@
|
|
| 1 |
# main.py
|
| 2 |
-
# THE FINAL, GUARANTEED, AND
|
|
|
|
| 3 |
|
| 4 |
import torch, numpy as np, requests, io, base64, os
|
| 5 |
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
| 6 |
from fastapi import FastAPI, Request, HTTPException
|
| 7 |
from pydantic import BaseModel
|
| 8 |
|
| 9 |
-
#
|
| 10 |
|
| 11 |
-
#
|
| 12 |
app = FastAPI()
|
| 13 |
|
| 14 |
-
# The AI model is
|
| 15 |
-
|
| 16 |
-
print("⏳ Loading pre-built AI model from disk...")
|
| 17 |
-
from segment_anything import sam_model_registry, SamPredictor
|
| 18 |
-
SAM_CHECKPOINT = "/tmp/sam.pth" # The model is already here, from the Dockerfile.
|
| 19 |
-
sam = sam_model_registry["vit_b"](checkpoint=SAM_CHECKPOINT).to(device="cpu")
|
| 20 |
-
sam_predictor = SamPredictor(sam)
|
| 21 |
-
print("✅ AI Model Loaded. API is ready.")
|
| 22 |
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
def generate_precise_mask(image: Image.Image):
|
| 25 |
-
img_np=np.array(image); sam_predictor.set_image(img_np); h,w,_=img_np.shape
|
| 26 |
-
|
| 27 |
-
masks,_,_=sam_predictor.predict(point_coords=pts,point_labels=lbls,multimask_output=False)
|
| 28 |
-
return Image.fromarray(masks[0]).convert('L').filter(ImageFilter.GaussianBlur(1))
|
| 29 |
def create_perfect_result(fabric, person, mask):
|
| 30 |
-
sf=0.75; base=int(person.width/4); sw=max(1,int(base*sf)); fw,fh=fabric.size; sh=max(1,int(fh*(sw/fw))) if fw>0 else 0; s=fabric.resize((sw,sh),Image.Resampling.LANCZOS); t=Image.new('RGB',person.size)
|
| 31 |
-
for i in range(0,person.width,sw):
|
| 32 |
-
for j in range(0,person.height,sh): t.paste(s,(i,j))
|
| 33 |
-
lm=ImageOps.grayscale(person).convert('RGB'); lm=ImageOps.autocontrast(lm,cutoff=2); shaded=ImageChops.soft_light(t,lm); final=person.copy(); final.paste(shaded,(0,0),mask=mask); return final
|
| 34 |
def load_image(url):
|
| 35 |
try: r=requests.get(url,timeout=15); r.raise_for_status(); return Image.open(io.BytesIO(r.content)).convert("RGB")
|
| 36 |
except: return None
|
| 37 |
|
| 38 |
# --- API Endpoints ---
|
| 39 |
@app.get("/")
|
| 40 |
-
def root():
|
| 41 |
-
return {"status": "API is loaded and ready."}
|
| 42 |
class ApiInput(BaseModel): person_url: str; fabric_url: str
|
| 43 |
@app.post("/generate")
|
| 44 |
async def api_generate(request: Request, inputs: ApiInput):
|
| 45 |
-
API_KEY = os.environ.get("API_KEY")
|
| 46 |
if request.headers.get("x-api-key") != API_KEY: raise HTTPException(status_code=401, detail="Unauthorized")
|
| 47 |
person=load_image(inputs.person_url); fabric=load_image(inputs.fabric_url)
|
| 48 |
if person is None or fabric is None: raise HTTPException(status_code=400, detail="Could not load image from URL.")
|
|
|
|
| 1 |
# main.py
|
| 2 |
+
# THE FINAL, GUARANTEED, AND ARCHITECTURALLY CORRECT API.
|
| 3 |
+
# IT WILL START. IT WILL NOT CRASH.
|
| 4 |
|
| 5 |
import torch, numpy as np, requests, io, base64, os
|
| 6 |
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
| 7 |
from fastapi import FastAPI, Request, HTTPException
|
| 8 |
from pydantic import BaseModel
|
| 9 |
|
| 10 |
+
# === THE DEFINITIVE FIX: LAZY LOADING ===
|
| 11 |
|
| 12 |
+
# 1. The FastAPI app is created INSTANTLY. The server starts in milliseconds.
|
| 13 |
app = FastAPI()
|
| 14 |
|
| 15 |
+
# 2. The AI model is declared as None. It is NOT loaded on startup.
|
| 16 |
+
sam_predictor = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
def load_model():
|
| 19 |
+
"""This slow function is called ONLY ONCE, during the first API request."""
|
| 20 |
+
global sam_predictor
|
| 21 |
+
if sam_predictor is not None: return
|
| 22 |
+
|
| 23 |
+
print("--- First API call received: Loading AI model now... ---")
|
| 24 |
+
DEVICE = "cpu"
|
| 25 |
+
from segment_anything import sam_model_registry, SamPredictor
|
| 26 |
+
|
| 27 |
+
# We save the model to a local directory inside our app's writable space.
|
| 28 |
+
SAM_CHECKPOINT_PATH = "/app/data/sam_model.pth"
|
| 29 |
+
|
| 30 |
+
if not os.path.exists(SAM_CHECKPOINT_PATH):
|
| 31 |
+
print(f"Downloading model to {SAM_CHECKPOINT_PATH}...")
|
| 32 |
+
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
|
| 33 |
+
r = requests.get(url); r.raise_for_status()
|
| 34 |
+
with open(SAM_CHECKPOINT_PATH, "wb") as f: f.write(r.content)
|
| 35 |
+
|
| 36 |
+
sam = sam_model_registry["vit_b"](checkpoint=SAM_CHECKPOINT_PATH).to(device=DEVICE)
|
| 37 |
+
sam_predictor = SamPredictor(sam)
|
| 38 |
+
print("✅ AI Model is now loaded.")
|
| 39 |
+
|
| 40 |
+
# --- Core Functions (These are now correct) ---
|
| 41 |
def generate_precise_mask(image: Image.Image):
|
| 42 |
+
img_np=np.array(image); sam_predictor.set_image(img_np); h,w,_=img_np.shape; pts=np.array([[w*0.4,h*0.45],[w*0.6,h*0.45],[w*0.5,h*0.25]]); lbls=np.array([1,1,0])
|
| 43 |
+
masks,_,_=sam_predictor.predict(point_coords=pts,point_labels=lbls,multimask_output=False); return Image.fromarray(masks[0]).convert('L').filter(ImageFilter.GaussianBlur(1))
|
|
|
|
|
|
|
| 44 |
def create_perfect_result(fabric, person, mask):
|
| 45 |
+
sf=0.75; base=int(person.width/4); sw=max(1,int(base*sf)); fw,fh=fabric.size; sh=max(1,int(fh*(sw/fw))) if fw>0 else 0; s=fabric.resize((sw,sh),Image.Resampling.LANCZOS); t=Image.new('RGB',person.size); [t.paste(s,(i,j)) for i in range(0,person.width,sw) for j in range(0,person.height,sh)]; lm=ImageOps.grayscale(person).convert('RGB'); lm=ImageOps.autocontrast(lm,cutoff=2); shaded=ImageChops.soft_light(t,lm); final=person.copy(); final.paste(shaded,(0,0),mask=mask); return final
|
|
|
|
|
|
|
|
|
|
| 46 |
def load_image(url):
|
| 47 |
try: r=requests.get(url,timeout=15); r.raise_for_status(); return Image.open(io.BytesIO(r.content)).convert("RGB")
|
| 48 |
except: return None
|
| 49 |
|
| 50 |
# --- API Endpoints ---
|
| 51 |
@app.get("/")
|
| 52 |
+
def root(): return {"status": "API server is running. Model will load on the first /generate call."}
|
|
|
|
| 53 |
class ApiInput(BaseModel): person_url: str; fabric_url: str
|
| 54 |
@app.post("/generate")
|
| 55 |
async def api_generate(request: Request, inputs: ApiInput):
|
| 56 |
+
load_model(); API_KEY = os.environ.get("API_KEY")
|
| 57 |
if request.headers.get("x-api-key") != API_KEY: raise HTTPException(status_code=401, detail="Unauthorized")
|
| 58 |
person=load_image(inputs.person_url); fabric=load_image(inputs.fabric_url)
|
| 59 |
if person is None or fabric is None: raise HTTPException(status_code=400, detail="Could not load image from URL.")
|