# (Hugging Face Spaces UI residue — "Spaces: Sleeping" — not part of the program.)
# main.py
# FastAPI service that re-textures the suit in a person photo with a tiled
# fabric swatch, using a lazily loaded Segment Anything (SAM) model for masking.
# Ported from a working Colab notebook; compositing constants match that version.
| import base64 | |
| import io | |
| import os | |
| from typing import Optional | |
| # These are the only libraries imported at the top level. | |
| from fastapi import FastAPI, Request, HTTPException | |
| from pydantic import BaseModel | |
| from PIL import Image, ImageOps, ImageChops, ImageFilter | |
| import requests | |
# === LAZY LOADING (UNCHANGED AND CORRECT) ===
app = FastAPI()
# Shared lazy-loading cache: load_model() fills both slots on the first request;
# "predictor" is a SamPredictor, "numpy" is the numpy module itself.
AI_MODEL = {"predictor": None, "numpy": None}
def load_model():
    """Lazily initialise the SAM predictor on the first API request.

    The heavy libraries (torch, numpy, segment_anything) are imported here
    rather than at module import time so the server starts quickly; once
    AI_MODEL["predictor"] is populated, subsequent calls are no-ops.
    """
    if AI_MODEL["predictor"] is not None:
        return
    print("--- First API call received: Loading AI model now. ---")
    import torch  # noqa: F401 - required by segment_anything
    import numpy
    from segment_anything import sam_model_registry, SamPredictor

    AI_MODEL["numpy"] = numpy
    checkpoint_path = "/tmp/sam_model.pth"
    # ViT-H checkpoint, pinned to CPU inference.
    sam_model = sam_model_registry["vit_h"](checkpoint=checkpoint_path).to(device="cpu")
    AI_MODEL["predictor"] = SamPredictor(sam_model)
    print("✅ High-Quality AI Model is now loaded.")
| # === CORE PROCESSING FUNCTIONS (A 100% IDENTICAL COPY FROM YOUR WORKING COLAB) === | |
def generate_ultimate_mask(image: Image.Image):
    """Generate the high-quality mask for the suit only, via SAM point prompts."""
    print(" - Generating new, high-precision mask...")
    predictor = AI_MODEL["predictor"]
    np = AI_MODEL["numpy"]
    rgb_array = np.array(image.convert('RGB'))
    predictor.set_image(rgb_array)
    h, w, _ = rgb_array.shape
    # Two positive clicks on the suit, two negative clicks to exclude shirt/face.
    point_coords = np.array([
        [w * 0.30, h * 0.50],
        [w * 0.70, h * 0.50],
        [w * 0.50, h * 0.40],
        [w * 0.65, h * 0.32],
    ])
    point_labels = np.array([1, 1, 0, 0])
    masks, _, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )
    # Soften the binary mask edge slightly for a cleaner composite.
    mask_image = Image.fromarray(masks[0]).convert('L')
    return mask_image.filter(ImageFilter.GaussianBlur(2))
def create_the_final_results(fabric: Image.Image, person: Image.Image, mask: Image.Image):
    """Composite the fabric onto the person using luminance-based relighting.

    Produces three tiling scales ("ultimate", "fine_weave", "bold_statement")
    plus a soft-light "creative_variation" derived from the ultimate result.
    Blend opacities (0.50 shadow, 0.20 highlight) match the Colab original.
    """
    print(" - Creating the final result images using professional layering...")
    gray = ImageOps.grayscale(person)
    # Lighting maps from the original suit's luminance: shadows come from the
    # dark end of the histogram, highlights from the bright end.
    shadow_map = ImageOps.autocontrast(gray, cutoff=(0, 75)).convert('RGB')
    highlight_map = ImageOps.autocontrast(gray, cutoff=(95, 100)).convert('RGB')

    def _tiled(scale_factor):
        # Tile the fabric swatch over the whole person canvas at this scale.
        tile_w = max(1, int(int(person.width / 4) * scale_factor))
        src_w, src_h = fabric.size
        tile_h = max(1, int(src_h * (tile_w / src_w))) if src_w > 0 else 0
        tile = fabric.resize((tile_w, tile_h), Image.Resampling.LANCZOS)
        canvas = Image.new('RGB', person.size)
        for x in range(0, person.width, tile_w):
            for y in range(0, person.height, tile_h):
                canvas.paste(tile, (x, y))
        return canvas

    results = {}
    for style, factor in {"ultimate": 0.65, "fine_weave": 0.4, "bold_statement": 1.2}.items():
        fabric_layer = _tiled(factor)
        # Shadow pass: multiply darkens, blended back at 50% opacity.
        darkened = ImageChops.multiply(fabric_layer, shadow_map)
        shaded_fabric = Image.blend(fabric_layer, darkened, alpha=0.50)
        # Highlight pass: screen brightens, blended back at 20% opacity.
        brightened = ImageChops.screen(shaded_fabric, highlight_map)
        lit_fabric = Image.blend(shaded_fabric, brightened, alpha=0.20)
        # Paste the relit fabric through the suit mask onto the person.
        composite = person.copy()
        composite.paste(lit_fabric, (0, 0), mask=mask)
        results[f"{style}_image"] = composite

    # Fourth, creative variation: soft-light the ultimate result with a form map.
    form_map = ImageOps.autocontrast(ImageOps.grayscale(person), cutoff=2).convert('RGB')
    results["creative_variation_image"] = ImageChops.soft_light(results["ultimate_image"], form_map)
    return results
def load_image_from_base64(base64_str: str, mode='RGB'):
    """Decode a Base64 (optionally data-URL) string into a PIL Image.

    Returns the image converted to *mode* (e.g. 'RGB', 'L'), or None when
    decoding fails — callers treat None as "could not decode".
    """
    # Strip a "data:image/...;base64," style prefix when present.
    if "," in base64_str:
        base64_str = base64_str.split(",")[1]
    try:
        raw_bytes = base64.b64decode(base64_str)
        return Image.open(io.BytesIO(raw_bytes)).convert(mode)
    except Exception as err:
        # Best-effort decode: log and signal failure with None.
        print(f"Error loading image from base64: {err}")
        return None
| # === API ENDPOINTS (CORRECT AND GUARANTEED) === | |
# Register the health-check route on the FastAPI app. The original had no
# route decorator, so this handler was never reachable.
@app.get("/")
def root():
    """Liveness probe: the heavy SAM model is loaded lazily on the first real call."""
    return {"status": "API server is running. Model will load on first call."}
class ApiInput(BaseModel):
    """Request payload for the generation endpoint; all images are Base64 strings."""
    # Photo of the person whose suit will be re-textured.
    person_base64: str
    # Fabric swatch to tile over the suit.
    fabric_base64: str
    # Optional pre-computed suit mask; when omitted, SAM generates one.
    mask_base64: Optional[str] = None
# NOTE(review): route path assumed to be "/generate" — the original had no
# decorator at all, so the handler was never registered; confirm the path
# against the existing clients.
@app.post("/generate")
async def api_generate(request: Request, inputs: ApiInput):
    """Composite a fabric texture onto the suit in a person photo.

    Auth: the X-API-Key header must match the API_KEY environment variable.
    Returns PNG data-URLs for the "ultimate" result and the mask that was used.
    Raises 401 on bad/missing key, 400 on undecodable images.
    """
    # Authenticate BEFORE any heavy work. The original loaded the model first,
    # and — worse — when the API_KEY env var was unset it compared
    # None != None and let every unauthenticated request through.
    api_key = os.environ.get("API_KEY")
    if not api_key or request.headers.get("x-api-key") != api_key:
        raise HTTPException(status_code=401, detail="Unauthorized")

    # Lazily load the SAM model on first authorized use.
    load_model()

    person = load_image_from_base64(inputs.person_base64)
    fabric = load_image_from_base64(inputs.fabric_base64)
    if person is None or fabric is None:
        raise HTTPException(status_code=400, detail="Could not decode base64 images.")

    # Normalise the person photo to the size the mask/compositing pipeline expects.
    TARGET_SIZE = (1024, 1024)
    person_resized = person.resize(TARGET_SIZE, Image.Resampling.LANCZOS)

    # Use the caller-supplied mask when given; otherwise segment the suit with SAM.
    if inputs.mask_base64:
        mask = load_image_from_base64(inputs.mask_base64, mode='L')
        if mask is None:
            raise HTTPException(status_code=400, detail="Could not decode mask base64.")
        mask = mask.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
    else:
        mask = generate_ultimate_mask(person_resized)

    final_results = create_the_final_results(fabric, person_resized, mask)

    def to_base64(img):
        # Downscale for transport and encode as a PNG data-URL.
        img_display = img.resize((512, 512), Image.Resampling.LANCZOS)
        buf = io.BytesIO()
        img_display.save(buf, format="PNG")
        return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}"

    # Only the "ultimate" composite and the mask are returned; the other
    # variations are computed but intentionally not serialized (saves payload).
    return {
        'ultimate_image': to_base64(final_results['ultimate_image']),
        'mask_image': to_base64(mask),
    }