File size: 20,118 Bytes
83039b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
"""
FastAPI server for Virtual Try-On using multiple model providers.

This server provides a simple REST API endpoint for virtual try-on generation
using various image generation models:
- Nano Banana and Nano Banana Pro (Google Gemini)
- FLUX 2 Pro and FLUX 2 Flex (Black Forest Labs)
"""

import os
import io
import base64
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Import adapters
from tryon.api.nano_banana import NanoBananaAdapter, NanoBananaProAdapter
from tryon.api.flux2 import Flux2ProAdapter, Flux2FlexAdapter

# Create output directory for generated images
OUTPUT_DIR = Path("outputs/virtual_tryon")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Supported aspect ratios for both adapters
SUPPORTED_ASPECT_RATIOS = [
    "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"
]


def calculate_aspect_ratio(image: Image.Image) -> str:
    """
    Return the supported aspect ratio closest to the image's own ratio.

    Args:
        image: PIL Image object

    Returns:
        str: Aspect ratio string in "W:H" form (e.g., "16:9")
    """
    width, height = image.size
    actual = width / height

    # Supported ratios expressed as decimal values.
    candidates = {
        "1:1": 1.0,
        "2:3": 2 / 3,
        "3:2": 3 / 2,
        "3:4": 3 / 4,
        "4:3": 4 / 3,
        "4:5": 4 / 5,
        "5:4": 5 / 4,
        "9:16": 9 / 16,
        "16:9": 16 / 9,
        "21:9": 21 / 9,
    }

    # min() keeps the first entry on ties; dicts iterate in insertion order,
    # so this matches a strict-less-than linear scan exactly.
    return min(candidates, key=lambda label: abs(actual - candidates[label]))


def get_image_dimensions(image: Image.Image) -> Tuple[int, int]:
    """
    Return the image's pixel dimensions.

    Args:
        image: PIL Image object

    Returns:
        tuple: (width, height) in pixels
    """
    width, height = image.size
    return width, height


def calculate_resolution(image: Image.Image) -> str:
    """
    Format the image's pixel dimensions as a "widthxheight" string.

    Args:
        image: PIL Image object

    Returns:
        str: Resolution string such as "1024x1024"
    """
    width, height = image.size
    return "{0}x{1}".format(width, height)


def map_resolution_to_pro_format(image: Image.Image) -> str:
    """
    Bucket an image's size into Nano Banana Pro's "1K"/"2K"/"4K" tiers.

    The tier is chosen by the larger of the two dimensions:
    - up to 1500 px  -> "1K"
    - up to 3000 px  -> "2K"
    - above 3000 px  -> "4K"

    Args:
        image: PIL Image object

    Returns:
        str: "1K", "2K", or "4K"
    """
    longest_side = max(image.size)

    # Check the tier ceilings in ascending order; fall through to "4K".
    for ceiling, tier in ((1500, "1K"), (3000, "2K")):
        if longest_side <= ceiling:
            return tier
    return "4K"

# FastAPI application instance; title/description/version feed the
# auto-generated OpenAPI docs at /docs.
app = FastAPI(
    title="TryOn AI Virtual Try-On API",
    description="Virtual try-on API using multiple model providers (Nano Banana, FLUX 2 Pro, FLUX 2 Flex)",
    version="1.0.0"
)

# CORS middleware to allow requests from frontend
# NOTE(review): the wildcard "*" alongside allow_credentials=True effectively
# opens the API to any origin — confirm this is acceptable (and how the
# installed Starlette version treats "*" with credentials) before hardening
# for production; the explicit origins above become redundant while "*" is set.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000", 
        "http://127.0.0.1:3000",
        "http://localhost:5173",
        "https://fyp-frontend-sandy.vercel.app",
        "*"  # Allow all origins for Hugging Face deployment (you can restrict this later)
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Root endpoint."""
    return {
        "message": "TryOn AI Virtual Try-On API",
        "version": "1.0.0",
        "endpoints": {
            "POST /api/v1/virtual-tryon": "Generate virtual try-on image"
        },
        "providers": [
            "nano-banana",
            "nano-banana-pro",
            "flux-2-pro",
            "flux-2-flex"
        ]
    }


@app.get("/health")
async def health():
    """Health check endpoint."""
    return {"status": "healthy"}


@app.post("/api/v1/virtual-tryon")
async def virtual_tryon(
    model_image: UploadFile = File(..., description="Model/person image"),
    garment_images: List[UploadFile] = File(..., description="Garment images"),
    provider: str = Form(default="nano-banana", description="Provider: 'nano-banana', 'nano-banana-pro', 'flux-2-pro', or 'flux-2-flex'"),
    prompt: Optional[str] = Form(default=None, description="Optional custom prompt"),
    resolution: Optional[str] = Form(default="1K", description="Resolution for nano-banana-pro: '1K', '2K', or '4K'"),
    aspect_ratio: Optional[str] = Form(default=None, description="Optional aspect ratio (e.g., '16:9')"),
    width: Optional[int] = Form(default=None, description="Output image width (for FLUX 2 models)"),
    height: Optional[int] = Form(default=None, description="Output image height (for FLUX 2 models)"),
    seed: Optional[int] = Form(default=None, description="Random seed for reproducibility (for FLUX 2 models)"),
    guidance: Optional[float] = Form(default=None, description="Guidance scale 1.5-10 (for FLUX 2 Flex, default: 3.5)"),
    steps: Optional[int] = Form(default=None, description="Number of generation steps (for FLUX 2 Flex, default: 28)"),
    safety_tolerance: Optional[int] = Form(default=2, description="Safety tolerance 0-5 (for FLUX 2 models, default: 2)")
):
    """
    Generate virtual try-on image from model image and garment images.
    
    Uses multi-image composition feature of various models to combine
    the model image with multiple garment images.
    
    Supported providers:
    - nano-banana: Google Gemini Nano Banana (basic)
    - nano-banana-pro: Google Gemini Nano Banana Pro (supports resolution)
    - flux-2-pro: Black Forest Labs FLUX 2 Pro (high quality)
    - flux-2-flex: Black Forest Labs FLUX 2 Flex (advanced controls)
    
    Args:
        model_image: Single model/person image
        garment_images: List of garment images (top, jeans, scarf, hat, etc.)
        provider: Model provider ('nano-banana', 'nano-banana-pro', 'flux-2-pro', or 'flux-2-flex')
        prompt: Optional custom prompt for generation
        resolution: Resolution for nano-banana-pro ('1K', '2K', or '4K')
        aspect_ratio: Optional aspect ratio (for Nano Banana models)
        width: Output image width in pixels (for FLUX 2 models, minimum: 64)
        height: Output image height in pixels (for FLUX 2 models, minimum: 64)
        seed: Random seed for reproducibility (for FLUX 2 models)
        guidance: Guidance scale 1.5-10 (for FLUX 2 Flex only, default: 3.5)
        steps: Number of generation steps (for FLUX 2 Flex only, default: 28)
        safety_tolerance: Safety tolerance 0-5 (for FLUX 2 models, default: 2)
        
    Returns:
        JSON response with base64-encoded result image
    """
    try:
        # Validate provider
        valid_providers = ["nano-banana", "nano-banana-pro", "flux-2-pro", "flux-2-flex"]
        if provider not in valid_providers:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid provider '{provider}'. Must be one of: {', '.join(valid_providers)}"
            )
    
        # Validate inputs
        if not model_image:
            raise HTTPException(status_code=400, detail="Model image is required")
    
        if not garment_images or len(garment_images) == 0:
            raise HTTPException(status_code=400, detail="At least one garment image is required")
    
        # Read model image
        model_image_bytes = await model_image.read()
        model_pil = Image.open(io.BytesIO(model_image_bytes))
    
        # Calculate aspect ratio and resolution from model image
        calculated_aspect_ratio = calculate_aspect_ratio(model_pil)
        calculated_resolution = calculate_resolution(model_pil)
        model_width, model_height = get_image_dimensions(model_pil)
    
        # Use calculated aspect ratio if not provided, otherwise use the provided one
        final_aspect_ratio = aspect_ratio if aspect_ratio else calculated_aspect_ratio
    
        # Map resolution to appropriate format based on provider
        # For nano-banana-pro, use "1K", "2K", or "4K" format
        # For nano-banana, resolution is not used (only aspect ratio)
        # For FLUX 2 models, use width/height parameters instead
        if provider == "nano-banana-pro":
            # Use provided resolution if valid, otherwise map from image dimensions
            if resolution and resolution in ["1K", "2K", "4K"]:
                final_resolution = resolution
            else:
                final_resolution = map_resolution_to_pro_format(model_pil)
        elif provider in ["flux-2-pro", "flux-2-flex"]:
            # For FLUX 2, use model dimensions as default width/height if not provided
            if width is None:
                width = model_width
            if height is None:
                height = model_height
            final_resolution = f"{width}x{height}"
        else:
            # For nano-banana, resolution is not used, but keep calculated for reference
            final_resolution = calculated_resolution
    
        # Read garment images and combine with model image
        # First image should be model, followed by garments
        images_list = [model_pil]
        for garment_file in garment_images:
            garment_bytes = await garment_file.read()
            garment_pil = Image.open(io.BytesIO(garment_bytes))
            images_list.append(garment_pil)
    
        # Prepare prompt
        if not prompt:
            prompt = (
                "Create a realistic virtual try-on image showing the person wearing the provided garments. "
                "CRITICAL REQUIREMENTS - Preserve all details exactly:\n"
                "1. GARMENT EXTRACTION: The garment images may contain people wearing the garments. "
                "IGNORE and EXTRACT ONLY the garment itself - do not use any person, model, or human figure "
                "from the garment images. Focus solely on the garment: its shape, design, patterns, colors, "
                "textures, and all visual details. Remove or ignore any human elements from garment images.\n"
                "2. GARMENT PRESERVATION: Keep ALL garment details completely intact - patterns, colors, textures, "
                "designs, prints, logos, text, embroidery, sequins, and any decorative elements must remain "
                "identical to the original garment images. Do not alter, fade, or modify any garment features.\n"
                "3. PERSON PRESERVATION: Keep the person's face, body shape, skin tone, hair, and physical "
                "characteristics exactly as shown in the FIRST image (model image). Only apply the extracted "
                "garments from the subsequent images to this person. Do not use any person from garment images.\n"
                "4. PARTIAL GARMENT HANDLING: If the person in the model image is wearing a full-body outfit "
                "(dress, jumpsuit, etc.) but the provided garment is only upper-body (top, shirt, blouse) or "
                "lower-body (pants, jeans, skirt), place the provided garment correctly over the corresponding "
                "body part. For the remaining uncovered body parts, generate an appropriate complementary garment "
                "that matches: (a) the person's physical characteristics and body type, (b) the person's style "
                "and personality traits visible in the model image, (c) the style, color scheme, and design "
                "aesthetic of the provided garment. The complementary garment should create a cohesive, "
                "harmonious outfit that looks natural and well-coordinated.\n"
                "5. FITTING: The extracted garments should fit naturally on the person's body from the first image, "
                "following their body contours and proportions realistically, while maintaining all original "
                "garment details from the garment images.\n"
                "6. COMPOSITION: The first image is the model/person to dress. The following images contain "
                "garments (top, bottom, accessories, etc.) - extract ONLY the garments from these images, "
                "ignoring any people shown. Combine the extracted garments to create a cohesive outfit where "
                "each garment maintains its original appearance and fits the person naturally.\n"
                "7. REALISM: The final image should look like a professional photograph of the person from the "
                "first image wearing the exact extracted garments (and complementary garments if needed), with "
                "realistic lighting, shadows, and fabric draping."
            )
    
        # Initialize adapter and generate
        if provider == "nano-banana":
            adapter = NanoBananaAdapter()
            # Generate with basic adapter using calculated aspect ratio
            result_images = adapter.generate_multi_image(
                images=images_list,
                prompt=prompt,
                aspect_ratio=final_aspect_ratio
            )
        elif provider == "nano-banana-pro":
            adapter = NanoBananaProAdapter()
            # Generate with Pro adapter (supports resolution) using calculated resolution and aspect ratio
            result_images = adapter.generate_multi_image(
                images=images_list,
                prompt=prompt,
                resolution=final_resolution,
                aspect_ratio=final_aspect_ratio
            )
        elif provider == "flux-2-pro":
            adapter = Flux2ProAdapter()
            # Generate with FLUX 2 Pro adapter
            # Build kwargs for FLUX 2 Pro
            flux_kwargs = {
                "safety_tolerance": safety_tolerance if safety_tolerance is not None else 2,
                "output_format": "png"
            }
            if width is not None:
                flux_kwargs["width"] = width
            if height is not None:
                flux_kwargs["height"] = height
            if seed is not None:
                flux_kwargs["seed"] = seed
        
            result_images = adapter.generate_multi_image(
                prompt=prompt,
                images=images_list,
                **flux_kwargs
            )
        else:  # flux-2-flex
            adapter = Flux2FlexAdapter()
            # Generate with FLUX 2 Flex adapter (supports guidance and steps)
            # Build kwargs for FLUX 2 Flex
            flux_kwargs = {
                "safety_tolerance": safety_tolerance if safety_tolerance is not None else 2,
                "output_format": "png",
                "guidance": guidance if guidance is not None else 5,
                "steps": steps if steps is not None else 50,
                "prompt_upsampling": True
            }
            if width is not None:
                flux_kwargs["width"] = width
            if height is not None:
                flux_kwargs["height"] = height
            if seed is not None:
                flux_kwargs["seed"] = seed
        
            result_images = adapter.generate_multi_image(
                prompt=prompt,
                images=images_list,
                **flux_kwargs
            )
    
        if not result_images:
            raise HTTPException(status_code=500, detail="No images generated")
    
        # Get first result image and convert to PIL Image
        result_image = result_images[0]
    
        # Convert image to PIL Image if needed
        # FLUX 2 adapters return PIL Images directly
        # Nano Banana adapters return Google GenAI image types that need conversion
        if not isinstance(result_image, Image.Image):
            # Google GenAI image type has image_bytes attribute
            if hasattr(result_image, 'image_bytes'):
                # Convert bytes to PIL Image
                result_image = Image.open(io.BytesIO(result_image.image_bytes))
            elif hasattr(result_image, 'to_pil'):
                # If it has a to_pil method, use it
                result_image = result_image.to_pil()
            else:
                # Try to get bytes from the image object
                try:
                    # Some GenAI image types expose bytes directly
                    image_bytes = bytes(result_image)
                    result_image = Image.open(io.BytesIO(image_bytes))
                except (TypeError, AttributeError):
                    raise HTTPException(
                        status_code=500,
                        detail=f"Unable to convert image type {type(result_image)} to PIL Image. "
                               f"Image attributes: {dir(result_image)}"
                    )
    
        # Generate filename with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"tryon_{provider}_{timestamp}.png"
        filepath = OUTPUT_DIR / filename
    
        # Save image to disk
        try:
            # Ensure image is in RGB mode for saving
            if result_image.mode != 'RGB':
                result_image = result_image.convert('RGB')
        
            # Save to file
            result_image.save(str(filepath), 'PNG')
        
            # Also save to BytesIO for base64 encoding
            img_buffer = io.BytesIO()
            result_image.save(img_buffer, 'PNG')
            img_buffer.seek(0)
            img_base64 = base64.b64encode(img_buffer.read()).decode('utf-8')
        
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error saving image: {str(e)}"
            )
    
        # Build response with provider-specific metadata
        response_data = {
            "success": True,
            "image": f"data:image/png;base64,{img_base64}",
            "provider": provider,
            "num_garments": len(garment_images),
            "saved_path": str(filepath),
            "filename": filename,
            "model_dimensions": {"width": model_width, "height": model_height},
        }
    
        # Add provider-specific metadata
        if provider in ["nano-banana", "nano-banana-pro"]:
            response_data.update({
                "aspect_ratio": final_aspect_ratio,
                "calculated_aspect_ratio": calculated_aspect_ratio,
                "resolution": final_resolution,
                "calculated_resolution": calculated_resolution
            })
        elif provider in ["flux-2-pro", "flux-2-flex"]:
            response_data.update({
                "output_dimensions": {"width": width or model_width, "height": height or model_height},
                "safety_tolerance": safety_tolerance if safety_tolerance is not None else 2,
            })
            if seed is not None:
                response_data["seed"] = seed
            if provider == "flux-2-flex":
                response_data.update({
                    "guidance": guidance if guidance is not None else 3.5,
                    "steps": steps if steps is not None else 28,
                })
    
        return JSONResponse(response_data)
    
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        import traceback
        error_details = f"Error generating try-on: {str(e)}\n{traceback.format_exc()}"
        print(error_details)  # Log to console
        raise HTTPException(status_code=500, detail=f"Error generating try-on: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)