Spaces:
Running
Running
| """ | |
| NAIA-WEB API Service | |
| NAI Image Generation API communication layer | |
| Reference: NAIA2.0/core/api_service.py (260-460) | |
| """ | |
| import aiohttp | |
| import asyncio | |
| import zipfile | |
| import io | |
| import json | |
| import random | |
| import base64 | |
| from dataclasses import dataclass | |
| from typing import Optional, Tuple, Dict, Any, List | |
| from PIL import Image | |
| from utils.constants import NAI_API_URL, MODEL_ID_MAP | |
def process_reference_image(file_path: str) -> str:
    """
    Process a reference image for the character reference API.

    The image is scaled to fit (aspect ratio preserved) and centered on a
    black canvas sized for whichever standard aspect ratio -- 2:3
    (1024x1536), 3:2 (1536x1024) or 1:1 (1472x1472) -- is closest to the
    source, then encoded as a base64 PNG.

    Reference: NAIA2.0/modules/character_reference_module.py _file_to_base64

    Args:
        file_path: Path to the image file on disk.

    Returns:
        Base64-encoded PNG (UTF-8 str). This is best-effort: if decoding or
        processing fails, the original file bytes are base64-encoded instead.
    """
    # Standard aspect ratios: name -> (ratio, canvas_width, canvas_height)
    ratios = {
        '2:3': (2/3, 1024, 1536),
        '3:2': (3/2, 1536, 1024),
        '1:1': (1/1, 1472, 1472),
    }
    try:
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it once the resized copy has been produced.
        with Image.open(file_path) as original_image:
            width, height = original_image.size
            aspect_ratio = width / height

            # Find closest standard ratio
            closest_ratio = min(ratios, key=lambda k: abs(aspect_ratio - ratios[k][0]))
            _, canvas_width, canvas_height = ratios[closest_ratio]
            print(f"NAIA-WEB: Reference image {width}x{height} ({aspect_ratio:.2f}) → {closest_ratio} ({canvas_width}x{canvas_height})")

            # Fit inside the canvas, preserving aspect ratio. Clamp to 1px so
            # extreme aspect ratios never request a zero-sized resize.
            if width / canvas_width > height / canvas_height:
                new_width = canvas_width
                new_height = max(1, int(height * (canvas_width / width)))
            else:
                new_height = canvas_height
                new_width = max(1, int(width * (canvas_height / height)))

            # Normalize other modes (L, LA, P, CMYK, ...) before resizing:
            # Pillow resizes palette images with NEAREST regardless of the
            # requested filter, and LA/P transparency would otherwise be lost.
            source = original_image
            if source.mode not in ('RGB', 'RGBA'):
                has_alpha = 'A' in source.getbands() or 'transparency' in source.info
                source = source.convert('RGBA' if has_alpha else 'RGB')

            resized_image = source.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Center on a black canvas
        canvas = Image.new('RGB', (canvas_width, canvas_height), (0, 0, 0))
        offset = ((canvas_width - new_width) // 2, (canvas_height - new_height) // 2)
        if resized_image.mode == 'RGBA':
            # Composite once over black using the image's own alpha as mask.
            # Pasting into an RGBA canvas and then re-pasting with the
            # canvas alpha would apply partial alpha twice, darkening soft edges.
            canvas.paste(resized_image, offset, resized_image)
        else:
            canvas.paste(resized_image, offset)

        # Encode to base64
        buffer = io.BytesIO()
        canvas.save(buffer, format="PNG", optimize=False)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        # Deliberate best-effort fallback: hand the API the untouched file.
        print(f"NAIA-WEB: Failed to process reference image: {e}")
        with open(file_path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")
class NAIAPIError(Exception):
    """
    Error raised when a call to the NAI API fails.

    Attributes:
        status_code: HTTP status returned by the API; this module uses 0 for
            failures that have no HTTP status (network errors, bad payloads).
        message: Raw error text.
        debug_info: Extra diagnostic fields; always a dict (empty when none
            were supplied).
    """

    def __init__(self, status_code: int, message: str, debug_info: Optional[Dict] = None):
        super().__init__(f"NAI API Error ({status_code}): {message}")
        self.status_code = status_code
        self.message = message
        self.debug_info = debug_info if debug_info else {}
@dataclass
class CharacterReferenceData:
    """
    Character reference input for NAID4.5 generation.

    Declared as a dataclass so it can be constructed with arguments,
    e.g. ``CharacterReferenceData(image_base64=..., fidelity=0.6)``.

    Attributes:
        image_base64: Base64-encoded reference image.
        style_aware: When True, the reference is described to the API as
            "character&style"; otherwise as "character" only.
        fidelity: How closely to follow the reference (0.0-1.0).
    """
    image_base64: str  # Base64 encoded image
    style_aware: bool = True  # Include style from reference
    fidelity: float = 0.75  # How closely to follow the reference (0.0-1.0)
@dataclass
class GenerationParameters:
    """
    Parameters for an image generation request.

    Declared as a dataclass so callers can construct it with keyword
    arguments; ``prompt``, ``negative_prompt``, ``width`` and ``height`` are
    required, everything else has a default.

    Attributes:
        prompt: Positive prompt (base caption).
        negative_prompt: Negative prompt (base undesired-content caption).
        width / height: Output image size in pixels.
        steps, scale, cfg_rescale, sampler, noise_schedule: Sampler settings
            forwarded to the API.
        seed: Fixed seed; None (or a value <= 0) lets the service pick a
            random seed per request.
        model: Short model key (e.g. "NAID4.5F") mapped to an API model name.
        variety_plus: VAR+ option (skip_cfg_above_sigma).
        character_prompts: Optional list of (prompt, negative) tuples.
        character_reference: Optional character reference (NAID4.5 feature).
    """
    prompt: str
    negative_prompt: str
    width: int
    height: int
    steps: int = 28
    scale: float = 5.0
    cfg_rescale: float = 0.4  # NAIA2.0 default
    sampler: str = "k_euler"
    seed: Optional[int] = None
    model: str = "NAID4.5F"
    noise_schedule: str = "native"
    variety_plus: bool = False  # VAR+ option (skip_cfg_above_sigma)
    # Character prompts: List of (prompt, negative) tuples; None means "none"
    character_prompts: Optional[List[Tuple[str, str]]] = None
    # Character reference (NAID4.5 feature)
    character_reference: Optional[CharacterReferenceData] = None
class NAIAPIService:
    """
    Service for communicating with the NAI image generation API.

    Builds V4/V4.5-style request payloads from GenerationParameters, posts
    them to NAI_API_URL, and extracts the generated PNG from the zip
    response. The most recent payload, response status and error text are
    kept for debugging (see get_debug_info).
    """

    def __init__(self):
        # Created lazily by _get_session and recreated if it was closed.
        self._session: Optional[aiohttp.ClientSession] = None
        # Debug info from the most recent generate_image call
        self._last_payload: Optional[Dict] = None
        self._last_response_status: Optional[int] = None
        self._last_response_text: Optional[str] = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, creating it if missing or closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def generate_image(
        self,
        token: str,
        params: GenerationParameters
    ) -> Tuple[Image.Image, Dict[str, Any]]:
        """
        Call the NAI API to generate a single image.

        Args:
            token: NAI API authentication token (sent as a Bearer token).
            params: Generation parameters.

        Returns:
            Tuple of (PIL Image, metadata dict with seed/model/steps/scale/
            sampler/width/height).

        Raises:
            NAIAPIError: If the API returns a non-200 status, the request
                fails at the network level or times out, or the response
                cannot be decoded into an image. 4xx responses are raised
                immediately; 5xx responses and network errors/timeouts are
                retried once.
        """
        session = await self._get_session()
        # Get model name from mapping
        model_name = MODEL_ID_MAP.get(params.model, "nai-diffusion-4-5-full")
        # A missing or non-positive seed means "pick a random one"
        seed = params.seed if params.seed and params.seed > 0 else random.randint(0, 2**32 - 1)

        # Build V4 prompt structure
        v4_prompt = {
            "caption": {
                "base_caption": params.prompt,
                "char_captions": []
            },
            "use_coords": False,
            "use_order": True
        }
        v4_negative_prompt = {
            "caption": {
                "base_caption": params.negative_prompt,
                "char_captions": []
            },
            "legacy_uc": False
        }

        # Add character prompts if provided (NAID4.5 feature); blank prompts are skipped
        if params.character_prompts:
            for char_prompt, char_negative in params.character_prompts:
                if char_prompt.strip():
                    # Default center position (no 5x5 grid feature)
                    centers = [{"x": 0.5, "y": 0.5}]
                    v4_prompt["caption"]["char_captions"].append({
                        "char_caption": char_prompt.strip(),
                        "centers": centers
                    })
                    v4_negative_prompt["caption"]["char_captions"].append({
                        "char_caption": char_negative.strip() if char_negative else "",
                        "centers": centers
                    })
            if v4_prompt["caption"]["char_captions"]:
                print(f"NAIA-WEB: Added {len(v4_prompt['caption']['char_captions'])} character prompt(s)")

        # Build API parameters (matching NAI V4 structure)
        api_parameters = {
            "width": params.width,
            "height": params.height,
            "n_samples": 1,
            "seed": seed,
            "extra_noise_seed": seed,
            "sampler": params.sampler,
            "steps": params.steps,
            "scale": params.scale,
            "cfg_rescale": params.cfg_rescale,
            "noise_schedule": params.noise_schedule,
            "negative_prompt": params.negative_prompt,
            # V4 specific parameters
            "params_version": 3,
            "add_original_image": True,
            "legacy": False,
            "legacy_uc": False,
            "autoSmea": True,
            "prefer_brownian": True,
            "ucPreset": 0,
            "use_coords": False,
            "v4_prompt": v4_prompt,
            "v4_negative_prompt": v4_negative_prompt,
        }

        # VAR+ (skip_cfg_above_sigma) handling
        # Reference: NAIA2.0/core/api_service.py:307-321
        if params.variety_plus:
            # NAID4.5: 58, NAID4.0/NAID3: 19
            if model_name in ['nai-diffusion-4-5-curated']:
                api_parameters["skip_cfg_above_sigma"] = 58
            elif model_name == 'nai-diffusion-4-5-full':
                api_parameters["skip_cfg_above_sigma"] = 58.93178654671047
            else:
                api_parameters["skip_cfg_above_sigma"] = 19
            print(f"NAIA-WEB: VAR+ enabled (skip_cfg_above_sigma={api_parameters['skip_cfg_above_sigma']})")
        else:
            api_parameters["skip_cfg_above_sigma"] = None

        # Add character reference if provided (NAID4.5 feature)
        if params.character_reference:
            ref = params.character_reference
            # Build description based on style_aware setting
            if ref.style_aware:
                description = {
                    "caption": {"base_caption": "character&style", "char_captions": []},
                    "legacy_uc": False
                }
            else:
                description = {
                    "caption": {"base_caption": "character", "char_captions": []},
                    "legacy_uc": False
                }
            api_parameters["director_reference_descriptions"] = [description]
            api_parameters["director_reference_images"] = [ref.image_base64]
            api_parameters["director_reference_information_extracted"] = [1]
            api_parameters["director_reference_secondary_strength_values"] = [ref.fidelity]
            api_parameters["director_reference_strength_values"] = [1]
            api_parameters["controlnet_strength"] = 1
            api_parameters["inpaintImg2ImgStrength"] = 1
            api_parameters["normalize_reference_strength_multiple"] = True
            # Remove skip_cfg_above_sigma when Character Reference is enabled
            # Reference: NAIA2.0/core/api_service.py:533-536
            if 'skip_cfg_above_sigma' in api_parameters:
                del api_parameters['skip_cfg_above_sigma']
                print("NAIA-WEB: skip_cfg_above_sigma removed (Character Reference enabled)")
            print(f"NAIA-WEB: Character reference enabled (style_aware={ref.style_aware}, fidelity={ref.fidelity})")

        # Build request payload
        payload = {
            "input": params.prompt,
            "model": model_name,
            "action": "generate",
            "parameters": api_parameters
        }
        # Headers - matching NAIA2.0 (no Accept header)
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json"
        }

        # Store for debugging
        self._last_payload = payload
        self._last_response_status = None
        self._last_response_text = None

        max_retries = 2
        last_error = None
        for attempt in range(max_retries):
            try:
                async with session.post(
                    NAI_API_URL,
                    json=payload,
                    headers=headers,
                    timeout=aiohttp.ClientTimeout(total=180)  # NAIA2.0 uses 180s
                ) as response:
                    self._last_response_status = response.status
                    if response.status == 200:
                        zip_data = await response.read()
                        image = self._extract_image_from_zip(zip_data)
                        metadata = {
                            "seed": seed,
                            "model": params.model,
                            "steps": params.steps,
                            "scale": params.scale,
                            "sampler": params.sampler,
                            "width": params.width,
                            "height": params.height,
                        }
                        return image, metadata
                    else:
                        error_text = await response.text()
                        self._last_response_text = error_text
                        debug_info = {
                            "model": model_name,
                            "status": response.status,
                            "response": error_text[:500],  # Truncate long responses
                            "token_length": len(token) if token else 0,
                            "token_prefix": token[:10] + "..." if token and len(token) > 10 else token
                        }
                        last_error = NAIAPIError(response.status, error_text, debug_info)
                        # Don't retry on client errors (4xx)
                        if 400 <= response.status < 500:
                            raise last_error
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                # ClientTimeout(total=...) expiry raises asyncio.TimeoutError,
                # which is NOT an aiohttp.ClientError; without catching it the
                # timeout would escape as a raw exception and skip the retry.
                # Timeout exceptions often have an empty str(), so fall back
                # to the exception type name.
                detail = str(e) or type(e).__name__
                self._last_response_text = detail
                last_error = NAIAPIError(0, f"Network error: {detail}")
            # Wait before retry
            if attempt < max_retries - 1:
                await asyncio.sleep(1)
        raise last_error or NAIAPIError(0, "Unknown error")

    def _extract_image_from_zip(self, zip_data: bytes) -> Image.Image:
        """
        Extract the first PNG image from an NAI zip response.

        Raises:
            NAIAPIError: (status 0) if the data is not a valid zip archive,
                contains no PNG entry, or the PNG cannot be decoded.
        """
        try:
            with zipfile.ZipFile(io.BytesIO(zip_data)) as zf:
                # Find PNG file in zip (extension match is case-insensitive)
                image_files = [f for f in zf.namelist() if f.lower().endswith('.png')]
                if not image_files:
                    raise NAIAPIError(0, "No image found in response")
                image_bytes = zf.read(image_files[0])
            image = Image.open(io.BytesIO(image_bytes))
            # Decode now so corrupt data is reported here rather than at first use.
            image.load()
            return image
        except (zipfile.BadZipFile, OSError) as e:
            # PIL's UnidentifiedImageError is an OSError subclass.
            raise NAIAPIError(0, f"Invalid image data in response: {e}") from e

    async def close(self):
        """Close the aiohttp session if it is open."""
        if self._session and not self._session.closed:
            await self._session.close()

    def get_debug_info(self) -> Dict[str, Any]:
        """Return status, response text, payload keys and model from the last request."""
        return {
            "last_status": self._last_response_status,
            "last_response": self._last_response_text,
            "last_payload_keys": list(self._last_payload.keys()) if self._last_payload else None,
            "last_model": self._last_payload.get("model") if self._last_payload else None,
        }
def format_api_error(error: NAIAPIError) -> str:
    """
    Build a user-facing message for an NAIAPIError.

    Well-known status codes (401, 402, 429, any 5xx) get friendly text,
    status 0 is reported as a connection error, and anything else shows the
    raw code and message. If the error carries debug info, a "[Debug Info]"
    section listing token length, token prefix, model and response
    (whichever are present) is appended.
    """
    code = error.status_code
    friendly_messages = {
        401: "Authentication failed. Please check your API token.",
        402: "Insufficient Anlas. Please check your account balance.",
        429: "Rate limited. Please wait before trying again.",
    }
    if code in friendly_messages:
        text = friendly_messages[code]
    elif code >= 500:
        text = "NAI server error. Please try again later."
    elif code == 0:
        text = f"Connection error: {error.message}"
    else:
        text = f"API Error ({error.status_code}): {error.message}"

    details = error.debug_info
    if not details:
        return text

    # (debug_info key, display label) in output order
    labelled_fields = (
        ("token_length", "Token length"),
        ("token_prefix", "Token prefix"),
        ("model", "Model"),
        ("response", "Response"),
    )
    lines = [f"{label}: {details[key]}" for key, label in labelled_fields if key in details]
    if lines:
        text += "\n\n[Debug Info]\n" + "\n".join(lines)
    return text