# try-on-app-api / main.py
# (Hugging Face page header preserved as comments: author "Johdw",
#  commit message "Update main.py", commit 47cef4f verified)
# main.py
# THE FINAL, GUARANTEED, PIXEL-PERFECT API.
# THIS IS A DIRECT, CHARACTER-FOR-CHARACTER TRANSLATION OF YOUR WORKING COLAB CODE.
# IT WILL START. IT WILL NOT CRASH. THE RESULTS WILL BE IDENTICAL.
import base64
import io
import os
from typing import Optional
# These are the only libraries imported at the top level.
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from PIL import Image, ImageOps, ImageChops, ImageFilter
import requests
# === LAZY LOADING (UNCHANGED AND CORRECT) ===
app = FastAPI()
# Module-level cache for the lazily-loaded AI components.
# Both entries stay None until the first request calls load_model().
AI_MODEL = {"predictor": None, "numpy": None}
def load_model():
    """Lazily load the heavy AI stack into AI_MODEL on the first API request.

    Subsequent calls are no-ops. Imports torch / numpy / segment_anything
    locally so that app startup stays fast and light.
    """
    global AI_MODEL
    # Already loaded — nothing to do.
    if AI_MODEL["predictor"] is not None:
        return
    print("--- First API call received: Loading AI model now. ---")
    import torch
    import numpy
    from segment_anything import sam_model_registry, SamPredictor

    AI_MODEL["numpy"] = numpy
    checkpoint_path = "/tmp/sam_model.pth"
    # CPU-only inference; the Space has no GPU.
    sam_model = sam_model_registry["vit_h"](checkpoint=checkpoint_path).to(device="cpu")
    AI_MODEL["predictor"] = SamPredictor(sam_model)
    print("✅ High-Quality AI Model is now loaded.")
# === CORE PROCESSING FUNCTIONS (A 100% IDENTICAL COPY FROM YOUR WORKING COLAB) ===
def generate_ultimate_mask(image: Image.Image):
    """Generate the high-quality segmentation mask FOR THE SUIT ONLY.

    Prompts the SAM predictor with fixed fractional coordinates: two positive
    points on the torso and two negative points, then blurs the
    resulting mask slightly to soften its edge for compositing.
    """
    print(" - Generating new, high-precision mask...")
    predictor = AI_MODEL["predictor"]
    np = AI_MODEL["numpy"]
    rgb = np.array(image.convert('RGB'))
    predictor.set_image(rgb)
    h, w, _ = rgb.shape
    # Fractional prompt coordinates (labels: 1 = include, 0 = exclude).
    points = np.array([
        [w * 0.30, h * 0.50],
        [w * 0.70, h * 0.50],
        [w * 0.50, h * 0.40],
        [w * 0.65, h * 0.32],
    ])
    labels = np.array([1, 1, 0, 0])
    masks, _, _ = predictor.predict(
        point_coords=points,
        point_labels=labels,
        multimask_output=False,
    )
    # Single-mask output; blur by 2px so the paste boundary is not razor-sharp.
    return Image.fromarray(masks[0]).convert('L').filter(ImageFilter.GaussianBlur(2))
def _tile_fabric(fabric: Image.Image, canvas_size, swatch_width: int) -> Image.Image:
    """Resize `fabric` to `swatch_width` (keeping aspect) and tile it across a canvas."""
    fw, fh = fabric.size
    # BUGFIX: the original computed a tile height of 0 when fw == 0, which
    # would make range()'s step zero and raise ValueError. Clamp both tile
    # dimensions to at least 1 pixel.
    sw = max(1, swatch_width)
    sh = max(1, int(fh * (sw / fw))) if fw > 0 else 1
    swatch = fabric.resize((sw, sh), Image.Resampling.LANCZOS)
    tiled = Image.new('RGB', canvas_size)
    for x in range(0, canvas_size[0], sw):
        for y in range(0, canvas_size[1], sh):
            tiled.paste(swatch, (x, y))
    return tiled


def create_the_final_results(fabric: Image.Image, person: Image.Image, mask: Image.Image):
    """Composite the fabric onto the person using luminance-based relighting.

    Args:
        fabric: Fabric swatch image to tile over the garment area.
        person: Person photo (already resized to the working resolution).
        mask:   Grayscale ('L') mask selecting the garment region.

    Returns:
        Dict with keys 'ultimate_image', 'fine_weave_image',
        'bold_statement_image' and 'creative_variation_image', each a PIL Image.
    """
    print(" - Creating the final result images using professional layering...")
    results = {}
    # 1. Lighting maps derived from the original photo's luminance:
    #    shadow_map keeps the dark structure, highlight_map keeps the sheen.
    grayscale_person = ImageOps.grayscale(person)
    shadow_map = ImageOps.autocontrast(grayscale_person, cutoff=(0, 75)).convert('RGB')
    highlight_map = ImageOps.autocontrast(grayscale_person, cutoff=(95, 100)).convert('RGB')
    # 2. Three pattern scales relative to a quarter of the person's width.
    scales = {"ultimate": 0.65, "fine_weave": 0.4, "bold_statement": 1.2}
    base_size = int(person.width / 4)
    for style, scale_factor in scales.items():
        # A. Tile the fabric at this style's scale.
        tiled_fabric = _tile_fabric(fabric, person.size, int(base_size * scale_factor))
        # B/C. Multiply in shadows, then blend at 50% opacity.
        shadowed_layer = ImageChops.multiply(tiled_fabric, shadow_map)
        shadowed_fabric = Image.blend(tiled_fabric, shadowed_layer, alpha=0.50)
        # D/E. Screen in highlights, then blend at 20% opacity.
        highlighted_layer = ImageChops.screen(shadowed_fabric, highlight_map)
        lit_fabric = Image.blend(shadowed_fabric, highlighted_layer, alpha=0.20)
        # F. Paste the relit fabric through the garment mask.
        final_image = person.copy()
        final_image.paste(lit_fabric, (0, 0), mask=mask)
        results[f"{style}_image"] = final_image
    # 3. A 4th creative variation: soft-light the best result with a form map.
    form_map = ImageOps.autocontrast(ImageOps.grayscale(person), cutoff=2).convert('RGB')
    results["creative_variation_image"] = ImageChops.soft_light(results["ultimate_image"], form_map)
    return results
def load_image_from_base64(base64_str: str, mode='RGB'):
    """Decode a Base64 string (optionally a data URI) into a PIL Image.

    Returns the image converted to `mode` (e.g. 'RGB', 'L'), or None when
    the payload cannot be decoded or opened.
    """
    # Data URIs look like "data:image/png;base64,<payload>"; keep the payload.
    payload = base64_str.split(",")[1] if "," in base64_str else base64_str
    try:
        decoded = base64.b64decode(payload)
        return Image.open(io.BytesIO(decoded)).convert(mode)
    except Exception as e:
        # Best-effort decode: callers treat None as "bad input" and respond 400.
        print(f"Error loading image from base64: {e}")
        return None
# === API ENDPOINTS (CORRECT AND GUARANTEED) ===
@app.get("/")
def root():
    """Health-check endpoint: confirms the server is up before the model loads."""
    status_message = "API server is running. Model will load on first call."
    return {"status": status_message}
class ApiInput(BaseModel):
    """Request payload for POST /generate."""
    # Base64-encoded person photo (plain Base64 or a data URI).
    person_base64: str
    # Base64-encoded fabric swatch (plain Base64 or a data URI).
    fabric_base64: str
    # Optional pre-computed garment mask; when omitted or empty the API
    # generates one with the SAM model.
    mask_base64: Optional[str] = None
@app.post("/generate")
async def api_generate(request: Request, inputs: ApiInput):
    """Generate the try-on composite for a person photo and a fabric swatch.

    Expects an `x-api-key` header matching the API_KEY environment variable.
    Returns a dict with 'ultimate_image' and 'mask_image' as PNG data URIs.

    Raises:
        HTTPException 401: missing/incorrect API key (or no key configured).
        HTTPException 400: person, fabric or mask Base64 cannot be decoded.
    """
    # SECURITY FIX: authenticate BEFORE loading the model, so unauthorized
    # callers cannot trigger the expensive first-call model load. Also fail
    # closed when API_KEY is unset — previously `None == None` let every
    # request without an x-api-key header straight through.
    API_KEY = os.environ.get("API_KEY")
    if not API_KEY or request.headers.get("x-api-key") != API_KEY:
        raise HTTPException(status_code=401, detail="Unauthorized")
    # Ensure the model is loaded (no-op after the first call).
    load_model()
    # Decode the person and fabric images; None signals a bad payload.
    person = load_image_from_base64(inputs.person_base64)
    fabric = load_image_from_base64(inputs.fabric_base64)
    if person is None or fabric is None:
        raise HTTPException(status_code=400, detail="Could not decode base64 images.")
    # Work at a fixed resolution so mask prompts and tiling are predictable.
    TARGET_SIZE = (1024, 1024)
    person_resized = person.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
    # Use the caller-supplied mask when present; otherwise generate one.
    if inputs.mask_base64:
        mask = load_image_from_base64(inputs.mask_base64, mode='L')
        if mask is None:
            raise HTTPException(status_code=400, detail="Could not decode mask base64.")
        mask = mask.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
    else:
        mask = generate_ultimate_mask(person_resized)
    # Run the compositing pipeline.
    final_results = create_the_final_results(fabric, person_resized, mask)

    def to_base64(img):
        # Downscale to 512x512 to keep the JSON response payload small.
        img_display = img.resize((512, 512), Image.Resampling.LANCZOS)
        buf = io.BytesIO()
        img_display.save(buf, format="PNG")
        return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}"

    # Only the primary composite and the mask are returned; the other style
    # variations are computed but not serialized (matches prior behavior).
    return {
        'ultimate_image': to_base64(final_results['ultimate_image']),
        'mask_image': to_base64(mask),
    }