| |
| |
| |
|
|
| from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Request, Form, Depends, Body |
| from typing import Optional, Tuple |
| from fastapi.responses import FileResponse |
| from huggingface_hub import hf_hub_download |
| import uuid |
| import os |
| import io |
| import json |
| import logging |
| import httpx |
| from urllib.parse import urlparse |
| from PIL import Image |
| import torch |
| from torchvision import transforms |
| import numpy as np |
| from pydantic import BaseModel, EmailStr |
| from app.database import ( |
| get_database, |
| log_api_call, |
| log_image_upload, |
| log_colorization, |
| log_media_click, |
| close_connection, |
| get_category_id_from_collage_maker, |
| get_category_id_from_ai_enhancer, |
| ) |
# The Firebase Admin SDK is optional: when it is not installed every Firebase
# name is bound to None so later code can test for availability.
try:
    import firebase_admin
    from firebase_admin import auth as firebase_auth, app_check, credentials
except ImportError:
    firebase_admin = None
    firebase_auth = None
    app_check = None
    credentials = None


# boto3 is optional as well; in the visible code it is only used to build the
# DigitalOcean Spaces client.
try:
    import boto3
except ImportError:
    boto3 = None


# The CCO colorizers (eccv16 / siggraph17) live in the app package.
# CCO_AVAILABLE records whether they could be imported.
try:
    from app.colorizers import eccv16, siggraph17
    from app.colorizers.util import preprocess_img, postprocess_tens
    CCO_AVAILABLE = True
except ImportError as e:
    print(f"⚠️ CCO colorizers not available: {e}")
    CCO_AVAILABLE = False
|
|
# Module-level logger shared by every handler in this file.
logger = logging.getLogger(__name__)


# The FastAPI application object that all routes below are registered on.
app = FastAPI(title="Text-Guided Image Colorization API")


from fastapi.middleware.cors import CORSMiddleware
# CORS is fully open: any origin, method and header, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
| |
| |
# Initialise the default Firebase app at import time when the SDK is installed
# and FIREBASE_CREDENTIALS holds a service-account JSON document.
# Any failure is printed and swallowed, leaving Firebase uninitialised.
if firebase_admin:
    try:
        firebase_json = os.getenv("FIREBASE_CREDENTIALS")

        if firebase_json:
            print("🔥 Loading Firebase credentials from ENV...")
            # The environment variable contains the JSON itself, not a file path.
            firebase_dict = json.loads(firebase_json)
            cred = credentials.Certificate(firebase_dict)
            firebase_admin.initialize_app(cred)
        else:
            print("⚠️ No Firebase credentials found. Firebase disabled.")

    except Exception as e:
        print("❌ Firebase initialization failed:", e)
else:
    print("⚠️ Firebase Admin SDK not available. Firebase features disabled.")
|
|
| |
| |
| |
# Working directories for uploaded inputs, full-size results and compressed
# results. They are created at import time if missing.
UPLOAD_DIR = "/tmp/uploads"
RESULTS_DIR = "/tmp/results"
COMPRESSED_DIR = "/tmp/compressed"
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(COMPRESSED_DIR, exist_ok=True)


# Fallback category id, overridable through DEFAULT_CATEGORY_FALLBACK.
# NOTE(review): not referenced by any code visible in this part of the file.
MEDIA_CLICK_DEFAULT_CATEGORY = os.getenv("DEFAULT_CATEGORY_FALLBACK", "69368fcd2e46bd68ae1889b2")


# DigitalOcean Spaces settings, all read from the environment.
# Only the base folder has a default; surrounding slashes are stripped from it.
DO_SPACES_KEY = os.getenv("DO_SPACES_KEY")
DO_SPACES_SECRET = os.getenv("DO_SPACES_SECRET")
DO_SPACES_REGION = os.getenv("DO_SPACES_REGION")
DO_SPACES_ENDPOINT = os.getenv("DO_SPACES_ENDPOINT")
DO_SPACES_BUCKET = os.getenv("DO_SPACES_BUCKET")
DO_SPACES_BASE_FOLDER = os.getenv("DO_SPACES_BASE_FOLDER", "valentine").strip("/")


# Lazily created client, populated by _get_spaces_client() on first use.
_spaces_client = None
|
|
def _spaces_enabled() -> bool:
    """Report whether uploads to DigitalOcean Spaces can be attempted.

    True only when every DO_SPACES_* connection setting is non-empty and the
    optional boto3 dependency was imported successfully.
    """
    required_settings = (
        DO_SPACES_KEY,
        DO_SPACES_SECRET,
        DO_SPACES_REGION,
        DO_SPACES_ENDPOINT,
        DO_SPACES_BUCKET,
    )
    if not all(required_settings):
        return False
    return boto3 is not None
|
|
def _get_spaces_client():
    """Return the shared S3 client for Spaces, creating it on first use.

    The client is cached in the module-level ``_spaces_client`` so later calls
    reuse the same object.
    """
    global _spaces_client
    if _spaces_client is not None:
        return _spaces_client

    connection_options = {
        "region_name": DO_SPACES_REGION,
        "endpoint_url": DO_SPACES_ENDPOINT,
        "aws_access_key_id": DO_SPACES_KEY,
        "aws_secret_access_key": DO_SPACES_SECRET,
    }
    _spaces_client = boto3.client("s3", **connection_options)
    return _spaces_client
|
|
def _build_spaces_public_url(object_key: str) -> str:
    """Build the public ``https://<bucket>.<endpoint-host>/<key>`` URL.

    The configured endpoint may be given with or without a scheme; only its
    host part is used.
    """
    endpoint = DO_SPACES_ENDPOINT
    if "://" not in endpoint:
        # urlparse only fills netloc when a scheme is present.
        endpoint = f"https://{endpoint}"
    pieces = urlparse(endpoint)
    host = (pieces.netloc or pieces.path).strip("/")
    return f"https://{DO_SPACES_BUCKET}.{host}/{object_key}"
|
|
def upload_bytes_to_spaces(file_bytes: bytes, object_key: str, content_type: str) -> str:
    """Upload an in-memory payload to Spaces as a public-read object.

    Args:
        file_bytes: Raw object content.
        object_key: Key (path) of the object inside the bucket.
        content_type: MIME type stored with the object.

    Returns:
        The public URL of the uploaded object.

    Raises:
        RuntimeError: If Spaces is not configured or boto3 is missing.
    """
    if not _spaces_enabled():
        raise RuntimeError("DigitalOcean Spaces is not configured")

    put_arguments = {
        "Bucket": DO_SPACES_BUCKET,
        "Key": object_key,
        "Body": file_bytes,
        "ACL": "public-read",
        "ContentType": content_type,
    }
    _get_spaces_client().put_object(**put_arguments)
    return _build_spaces_public_url(object_key)
|
|
def upload_file_to_spaces(file_path: str, object_key: str, content_type: str) -> str:
    """Upload a file from disk to Spaces as a public-read object.

    Args:
        file_path: Path of the local file to send.
        object_key: Key (path) of the object inside the bucket.
        content_type: MIME type stored with the object.

    Returns:
        The public URL of the uploaded object.

    Raises:
        RuntimeError: If Spaces is not configured or boto3 is missing.
    """
    if not _spaces_enabled():
        raise RuntimeError("DigitalOcean Spaces is not configured")

    s3 = _get_spaces_client()
    # The open file handle is passed as the body; it is closed once the call returns.
    with open(file_path, "rb") as source:
        s3.put_object(
            Bucket=DO_SPACES_BUCKET,
            Key=object_key,
            Body=source,
            ACL="public-read",
            ContentType=content_type,
        )
    return _build_spaces_public_url(object_key)
|
|
| |
| |
| |
# Hugging Face repository and file holding the GAN generator weights.
MODEL_REPO = "Hammad712/GAN-Colorization-Model"
MODEL_FILENAME = "generator.pt"


# Runs at import time.
print("⬇️ Downloading GAN model...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)


print("📦 Loading GAN model weights...")
# NOTE(review): state_dict is not used by any code visible in this file
# (colorize_image_gan is a placeholder) — confirm whether this load is needed.
# NOTE(review): torch.load is called without weights_only on a downloaded
# file — confirm the source is trusted.
state_dict = torch.load(model_path, map_location="cpu")
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
# Registry of loaded CCO models keyed by name ("eccv16", "siggraph17").
# It stays empty when the colorizers package could not be imported.
cco_models = {}
if CCO_AVAILABLE:
    print("📦 Loading CCO models...")
    try:
        # Both models are built with pretrained weights and put in eval mode.
        cco_models["eccv16"] = eccv16(pretrained=True).eval()
        cco_models["siggraph17"] = siggraph17(pretrained=True).eval()
        print("✅ CCO models loaded successfully!")
    except Exception as e:
        print(f"⚠️ Failed to load CCO models: {e}")
        # A load failure disables CCO for the whole process.
        CCO_AVAILABLE = False
|
|
def colorize_image_gan(img: Image.Image):
    """Placeholder GAN colorizer.

    No model is run: the image is converted to grayscale and that single
    channel is copied into three channels, so the output is a gray image.
    """
    grayscale = img.convert("L")
    single_channel = transforms.ToTensor()(grayscale).unsqueeze(0)
    three_channel = single_channel.repeat(1, 3, 1, 1)
    return transforms.ToPILImage()(three_channel.squeeze())
|
|
def colorize_image_cco(img: Image.Image, model_name: str = "eccv16"):
    """Colorize an image with one of the loaded CCO models.

    Args:
        img: Input PIL image; converted to RGB when it is in another mode.
        model_name: "eccv16" or "siggraph17". Any other value silently falls
            back to "eccv16".

    Returns:
        The colorized image as an RGB PIL image.

    Raises:
        ValueError: If CCO is unavailable or the selected model is not loaded.
    """
    if not CCO_AVAILABLE:
        raise ValueError("CCO models are not available")

    if model_name not in ("eccv16", "siggraph17"):
        model_name = "eccv16"

    network = cco_models.get(model_name)
    if network is None:
        raise ValueError(f"CCO model '{model_name}' not loaded")

    rgb_input = img if img.mode == "RGB" else img.convert("RGB")

    pixels = np.asarray(rgb_input)
    if pixels.ndim == 2:
        # Single-channel array: replicate it into three channels.
        pixels = np.tile(pixels[:, :, None], 3)

    # preprocess_img returns the lightness tensor at original size and a
    # resized one that is fed to the network.
    lightness_original, lightness_resized = preprocess_img(pixels)

    with torch.no_grad():
        predicted_ab = network(lightness_resized)

    rgb_float = np.clip(postprocess_tens(lightness_original, predicted_ab), 0, 1)
    rgb_bytes = (rgb_float * 255).astype(np.uint8)
    return Image.fromarray(rgb_bytes, 'RGB')
|
|
def colorize_image(img: Image.Image, model_type: str = "gan", cco_model: str = "eccv16"):
    """
    Dispatch colorization to the requested backend.

    Args:
        img: PIL Image to colorize
        model_type: "cco" selects the CCO backend; every other value uses the GAN path
        cco_model: "eccv16" or "siggraph17" (ignored unless model_type is "cco")

    Returns:
        Colorized PIL Image
    """
    if model_type != "cco":
        return colorize_image_gan(img)
    return colorize_image_cco(img, cco_model)
|
|
def compress_image_to_target_size(
    img: "Image.Image",
    target_size_mb: float = 2.5,
    min_size_mb: float = 2.0,
    max_size_mb: float = 3.0,
) -> Tuple["Image.Image", str]:
    """
    Compress an image to a JPEG whose file size falls within a size window.

    JPEG quality and, if needed, the image dimensions are adjusted over at most
    20 attempts. The search stops early when the size is inside the window or
    when no further adjustment is possible.

    Args:
        img: PIL Image to compress (converted to RGB when necessary)
        target_size_mb: Kept for backward compatibility; the search is driven
            by min_size_mb / max_size_mb only
        min_size_mb: Minimum acceptable file size in MB (default: 2.0MB)
        max_size_mb: Maximum acceptable file size in MB (default: 3.0MB)

    Returns:
        Tuple of (compressed PIL Image, path to the saved JPEG). The file at
        the returned path always exists; the caller owns it and is responsible
        for moving or deleting it.
    """
    import tempfile

    if img.mode != "RGB":
        img = img.convert("RGB")

    original_width, original_height = img.size
    quality = 85
    scale_factor = 1.0
    max_attempts = 20

    # A single temp file is reused for every attempt. Previously each early
    # return deleted its own file in a ``finally`` block, so callers received a
    # path that no longer existed.
    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file:
        temp_path = tmp_file.name

    candidate = img
    try:
        for attempt in range(max_attempts):
            if scale_factor < 1.0:
                # Clamp to 1px so tiny inputs never produce a zero-sized image.
                new_width = max(1, int(original_width * scale_factor))
                new_height = max(1, int(original_height * scale_factor))
                candidate = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
            else:
                candidate = img

            candidate.save(temp_path, "JPEG", quality=quality, optimize=True)
            file_size_mb = os.path.getsize(temp_path) / (1024 * 1024)

            logger.info("Compression attempt %d: quality=%d, scale=%.2f, size=%.2f MB",
                        attempt + 1, quality, scale_factor, file_size_mb)

            if min_size_mb <= file_size_mb <= max_size_mb:
                logger.info("Target size achieved: %.2f MB", file_size_mb)
                return candidate, temp_path

            if file_size_mb > max_size_mb:
                # Too large: lower quality first, then shrink the dimensions.
                if quality > 30:
                    quality -= 5
                elif scale_factor > 0.5:
                    scale_factor -= 0.05
                    quality = 75
                else:
                    logger.warning("Reached minimum compression, file size: %.2f MB", file_size_mb)
                    return candidate, temp_path
            elif file_size_mb < 1.0:
                # Far below the window: raise quality in big steps, then undo
                # any downscaling.
                if quality < 90:
                    quality += 5
                elif scale_factor < 1.0:
                    scale_factor = min(1.0, scale_factor + 0.1)
                else:
                    # Nothing left to adjust; repeating the same save would
                    # only burn the remaining attempts.
                    logger.info("File size %.2f MB is below minimum but at max quality, accepting", file_size_mb)
                    return candidate, temp_path
            elif quality < 95:
                # Slightly below the window: nudge quality up, capped at 95.
                quality = min(95, quality + 2)
            else:
                logger.info("File size %.2f MB is below minimum but at max quality, accepting", file_size_mb)
                return candidate, temp_path
    except Exception:
        # Do not leak the temp file when resizing or saving fails.
        try:
            os.unlink(temp_path)
        except OSError:
            pass
        raise

    # Attempts exhausted: hand back the last candidate that was written.
    return candidate, temp_path
|
|
| |
| |
| |
@app.on_event("startup")
async def startup_event():
    """Touch the MongoDB handle at startup and report the outcome on stdout.

    Failures are printed and swallowed so the application still starts.
    """
    try:
        if get_database() is not None:
            print("✅ MongoDB initialized successfully!")
    except Exception as exc:
        print(f"⚠️ MongoDB initialization failed: {exc}")
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Close the database connection when the application shuts down."""
    close_connection()
    print("Application shutdown")
|
|
| |
| |
| |
@app.get("/health")
def health_check(request: Request):
    """Return a static health payload and record the call.

    The payload is hard-coded; it does not inspect the loaded models.
    """
    caller = request.client
    payload = dict(
        status="healthy",
        model_loaded=True,
        model_type="hf_inference_api",
        provider="fal-ai",
    )

    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=payload,
        ip_address=caller.host if caller else None,
    )
    return payload
|
|
| |
| |
| |
# Request body for POST /auth/register.
class RegisterRequest(BaseModel):
    email: EmailStr  # validated as an e-mail address by pydantic
    password: str  # forwarded to firebase_auth.create_user
    display_name: Optional[str] = None  # optional; forwarded to firebase_auth.create_user
|
|
# Request body for POST /auth/login.
class LoginRequest(BaseModel):
    email: EmailStr  # validated as an e-mail address by pydantic
    password: str  # sent to the Firebase signInWithPassword REST call
|
|
# Response body shared by /auth/register and /auth/login.
class TokenResponse(BaseModel):
    # Firebase ID token on the REST login path; /auth/register places a
    # Firebase custom token here instead.
    id_token: str
    refresh_token: Optional[str] = None  # only set when the REST login returns one
    expires_in: int  # lifetime in seconds
    token_type: str = "Bearer"
    user: dict  # uid, email, display_name, email_verified
|
|
| |
| |
| |
| def _extract_bearer_token(authorization_header: Optional[str]) -> Optional[str]: |
| if not authorization_header: |
| return None |
| parts = authorization_header.split(" ", 1) |
| if len(parts) == 2 and parts[0].lower() == "bearer": |
| return parts[1].strip() |
| return None |
|
|
async def verify_request(request: Request):
    """
    Best-effort Firebase verification used as a FastAPI dependency.

    Order of checks:
    1. Firebase App Check token (X-Firebase-AppCheck header)
    2. Firebase Auth ID token (Authorization: Bearer header); on success the
       decoded claims are stored on ``request.state.user``

    This dependency never rejects a request: it returns True whether or not a
    token was present or valid. Callers that need an authenticated user must
    check ``request.state.user`` themselves.
    """
    # Firebase missing or not initialised: nothing to verify against.
    if not firebase_admin or not getattr(firebase_admin, "_apps", None):
        return True

    headers = request.headers
    appcheck_header = headers.get("X-Firebase-AppCheck") or headers.get("x-firebase-appcheck")
    if appcheck_header and app_check:
        try:
            claims = app_check.verify_token(appcheck_header)
            logger.info("App Check token verified for: %s", claims.get("app_id"))
            return True
        except Exception as exc:
            logger.warning("App Check token verification failed: %s", str(exc))

    id_token = _extract_bearer_token(headers.get("Authorization"))
    if id_token and firebase_auth:
        try:
            claims = firebase_auth.verify_id_token(id_token)
            request.state.user = claims
            logger.info("Firebase Auth id_token verified for uid: %s", claims.get("uid"))
            return True
        except Exception as exc:
            logger.warning("Auth token verification failed: %s", str(exc))

    # Fall through: unauthenticated requests are still allowed.
    return True
|
|
# These names are already imported at the top of the module. The re-imports are
# kept, but firebase_admin is an optional dependency, so its import is guarded:
# an unguarded import here would crash the whole module when the SDK is absent.
try:
    from firebase_admin import app_check
except ImportError:
    app_check = None
from fastapi import HTTPException
import os
|
|
def verify_app_check_token(
    token: str | None,
    *,
    required: bool = False
):
    """
    Validate a Firebase App Check token.

    With required=False (the default) the check is advisory: a missing token,
    an unavailable App Check module and an invalid token are all tolerated,
    the last two with a warning.

    With required=True a missing or invalid token raises HTTP 401 and an
    unavailable App Check module raises HTTP 503.

    Returns:
        True whenever the request is allowed to proceed.
    """
    if not token:
        if not required:
            return True
        raise HTTPException(
            status_code=401,
            detail="Firebase App Check token missing"
        )

    try:
        if app_check:
            app_check.verify_token(token)
            logger.debug("Firebase App Check token verified successfully")
            return True

        # A token was supplied but there is nothing to verify it with.
        if required:
            raise HTTPException(
                status_code=503,
                detail="Firebase App Check not configured"
            )
        logger.warning("Firebase App Check not available, but token was provided. Allowing request.")
        return True
    except HTTPException:
        # Our own 503 above must not be converted into a 401.
        raise
    except Exception as exc:
        if required:
            logger.error("Invalid Firebase App Check token (required): %s", str(exc))
            raise HTTPException(
                status_code=401,
                detail="Invalid Firebase App Check token"
            )
        logger.warning("Invalid Firebase App Check token provided, but allowing request (optional): %s", str(exc))
        return True
|
|
|
|
def _resolve_user_id(request: Request, supplied_user_id: Optional[str]) -> Optional[str]:
    """Return the caller-supplied user id with whitespace trimmed, or None when it is missing or blank.

    ``request`` is accepted but not used.
    """
    trimmed = supplied_user_id.strip() if supplied_user_id else ""
    return trimmed or None
|
|
| |
| |
| |
@app.post("/auth/register", response_model=TokenResponse)
async def register_user(user_data: RegisterRequest):
    """
    Register a new user with email and password.

    Returns a token for the new account in the ``id_token`` field.
    NOTE(review): the value is a Firebase *custom* token from
    create_custom_token, not an ID token — confirm that clients exchange it
    before using it as a Bearer credential.

    Raises:
        HTTPException 503: Firebase SDK missing or not initialised.
        HTTPException 400: E-mail already registered, or the SDK rejected the
            supplied values (for example a too-short password).
        HTTPException 500: Any other failure.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")

    if not hasattr(firebase_admin, '_apps') or not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")

    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")

    try:
        user_record = firebase_auth.create_user(
            email=user_data.email,
            password=user_data.password,
            display_name=user_data.display_name,
            email_verified=False
        )

        custom_token = firebase_auth.create_custom_token(user_record.uid)

        logger.info("User registered: %s (uid: %s)", user_data.email, user_record.uid)

        return TokenResponse(
            id_token=custom_token.decode('utf-8') if isinstance(custom_token, bytes) else custom_token,
            token_type="Bearer",
            expires_in=3600,
            user={
                "uid": user_record.uid,
                "email": user_record.email,
                "display_name": user_record.display_name or "",
                "email_verified": user_record.email_verified
            }
        )
    except Exception as e:
        error_msg = str(e)

        # Detect the duplicate-email case by exception type where the SDK
        # provides one. The old check matched any message containing "email",
        # so unrelated failures were reported as "already registered".
        already_exists_error = getattr(firebase_auth, "EmailAlreadyExistsError", None)
        is_duplicate = (
            already_exists_error is not None and isinstance(e, already_exists_error)
        ) or "already exists" in error_msg.lower()
        if is_duplicate:
            raise HTTPException(status_code=400, detail="Email already registered") from e

        # The SDK raises ValueError for values it rejects before calling Firebase.
        if isinstance(e, ValueError):
            raise HTTPException(status_code=400, detail=f"Invalid registration data: {error_msg}") from e

        logger.error("Registration error: %s", error_msg)
        raise HTTPException(status_code=500, detail=f"Registration failed: {error_msg}") from e
|
|
@app.post("/auth/login", response_model=TokenResponse)
async def login_user(credentials: LoginRequest):
    """
    Login with email and password.

    The password is verified by the Firebase signInWithPassword REST endpoint,
    which requires FIREBASE_API_KEY (environment variable or app settings).

    Raises:
        HTTPException 503: Firebase SDK missing or not initialised, or no API
            key configured (the password cannot be verified without it).
        HTTPException 401: Firebase rejected the credentials.
        HTTPException 500: Transport error or unexpected failure.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")

    if not hasattr(firebase_admin, '_apps') or not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")

    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")

    from app.config import settings
    firebase_api_key = os.getenv("FIREBASE_API_KEY") or getattr(settings, 'FIREBASE_API_KEY', '')

    if not firebase_api_key:
        # SECURITY: the previous fallback looked the user up by e-mail and
        # issued a custom token without checking the password, so anyone who
        # knew an e-mail address could log in as that user. Refuse instead.
        logger.error("Login refused: FIREBASE_API_KEY is not configured, so the password cannot be verified")
        raise HTTPException(status_code=503, detail="Password login is not configured")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={firebase_api_key}",
                json={
                    "email": credentials.email,
                    "password": credentials.password,
                    "returnSecureToken": True
                }
            )

        if response.status_code != 200:
            error_data = response.json()
            error_msg = error_data.get("error", {}).get("message", "Authentication failed")
            raise HTTPException(status_code=401, detail=error_msg)

        data = response.json()
        logger.info("User login successful: %s", credentials.email)

        # Fetch the profile so the response carries display name and
        # verification state, which the REST reply is not read for here.
        user_record = firebase_auth.get_user(data["localId"])

        return TokenResponse(
            id_token=data["idToken"],
            refresh_token=data.get("refreshToken"),
            expires_in=int(data.get("expiresIn", 3600)),
            token_type="Bearer",
            user={
                "uid": user_record.uid,
                "email": user_record.email,
                "display_name": user_record.display_name or "",
                "email_verified": user_record.email_verified
            }
        )
    except httpx.HTTPError as e:
        logger.error("HTTP error during login: %s", str(e))
        raise HTTPException(status_code=500, detail="Authentication service unavailable")
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Login error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Login failed: {str(e)}")
|
|
@app.get("/auth/me")
async def get_current_user(request: Request, verified: bool = Depends(verify_request)):
    """Return the profile of the user whose ID token was verified for this request.

    Relies on ``verify_request`` having stored the decoded token on
    ``request.state.user``; responds 401 when it did not.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")

    firebase_initialised = hasattr(firebase_admin, '_apps') and bool(firebase_admin._apps)
    if not firebase_initialised:
        raise HTTPException(status_code=503, detail="Firebase not initialized")

    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")

    has_verified_user = hasattr(request, 'state') and hasattr(request.state, 'user')
    if not has_verified_user:
        raise HTTPException(status_code=401, detail="Not authenticated")

    uid = request.state.user.get("uid")
    try:
        record = firebase_auth.get_user(uid)
        profile = {
            "uid": record.uid,
            "email": record.email,
            "display_name": record.display_name or "",
            "email_verified": record.email_verified,
        }
        return profile
    except Exception as exc:
        logger.error("Error getting user: %s", str(exc))
        raise HTTPException(status_code=404, detail="User not found")
|
|
@app.post("/auth/refresh")
async def refresh_token(refresh_token_param: str = Body(..., embed=True)):
    """Exchange a Firebase refresh token for a new ID token.

    Calls the securetoken REST endpoint, which requires FIREBASE_API_KEY
    (environment variable or app settings); responds 503 when no key is set.
    """
    from app.config import settings

    api_key = os.getenv("FIREBASE_API_KEY") or getattr(settings, 'FIREBASE_API_KEY', '')
    if not api_key:
        raise HTTPException(status_code=503, detail="Firebase API key not configured")

    exchange_url = f"https://securetoken.googleapis.com/v1/token?key={api_key}"
    exchange_body = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token_param
    }

    try:
        async with httpx.AsyncClient() as http:
            reply = await http.post(exchange_url, json=exchange_body)

        if reply.status_code != 200:
            failure = reply.json()
            reason = failure.get("error", {}).get("message", "Token refresh failed")
            raise HTTPException(status_code=401, detail=reason)

        tokens = reply.json()
        return {
            "id_token": tokens["id_token"],
            "refresh_token": tokens.get("refresh_token"),
            "expires_in": int(tokens.get("expires_in", 3600)),
            "token_type": "Bearer"
        }
    except httpx.HTTPError as exc:
        logger.error("HTTP error during token refresh: %s", str(exc))
        raise HTTPException(status_code=500, detail="Token refresh service unavailable")
    except HTTPException:
        raise
    except Exception as exc:
        logger.error("Token refresh error: %s", str(exc))
        raise HTTPException(status_code=500, detail=f"Token refresh failed: {str(exc)}")
|
|
| |
| |
| |
@app.post("/upload")
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
):
    """Store an uploaded image under UPLOAD_DIR and return its id and URL.

    The file is saved as ``<uuid>.jpg`` regardless of its real format; the
    bytes are written unchanged. ``category_id`` / ``categoryId`` are accepted
    for client compatibility but are not used by this endpoint.

    Raises:
        HTTPException 400: The part has no content type or it is not ``image/*``.
    """
    verify_app_check_token(x_firebase_appcheck)

    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)

    # content_type is None when the multipart part carries no Content-Type
    # header; treat that as an invalid type instead of crashing with a 500.
    content_type = file.content_type or ""
    if not content_type.startswith("image/"):
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=400,
            error="Invalid file type",
            ip_address=ip_address
        )
        raise HTTPException(status_code=400, detail="Invalid file type")

    image_id = f"{uuid.uuid4()}.jpg"
    file_path = os.path.join(UPLOAD_DIR, image_id)

    img_bytes = await file.read()
    file_size = len(img_bytes)

    with open(file_path, "wb") as f:
        f.write(img_bytes)

    base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"

    response_data = {
        "success": True,
        "image_id": image_id.replace(".jpg", ""),
        "file_url": f"{base_url}/uploads/{image_id}"
    }

    log_image_upload(
        image_id=image_id.replace(".jpg", ""),
        filename=file.filename or image_id,
        file_size=file_size,
        content_type=file.content_type or "image/jpeg",
        user_id=effective_user_id,
        ip_address=ip_address
    )

    log_api_call(
        endpoint="/upload",
        method="POST",
        status_code=200,
        request_data={"filename": file.filename, "content_type": file.content_type},
        response_data=response_data,
        user_id=effective_user_id,
        ip_address=ip_address
    )

    return response_data
|
|
| |
| |
| |
@app.post("/colorize")
async def colorize(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
    model: Optional[str] = Form(None),
    appname: Optional[str] = Form(None),
):
    """Colorize an uploaded image and return URLs for the result.

    The full-size result is written to RESULTS_DIR as PNG and a size-targeted
    JPEG to COMPRESSED_DIR. When DigitalOcean Spaces is configured, the source
    and the compressed result are also uploaded there and the Spaces URL is
    returned as ``Compressed_Image_URL``.

    Form fields:
        model: "gan", "cco", "cco-eccv16" or "cco-siggraph17". Anything else
            (or nothing) uses CCO eccv16 when available, otherwise the GAN path.
        appname: "collage-maker" or "ai-enhancer" trigger a category lookup
            when no category id was supplied.

    Raises:
        HTTPException 400: CCO requested but unavailable, or the upload is not
            an image.
        HTTPException 500: Any failure while processing the image.
    """
    import shutil
    import time

    start_time = time.time()

    verify_app_check_token(x_firebase_appcheck, required=False)

    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)
    # Either spelling is accepted; blank values count as absent.
    effective_category_id = (category_id or categoryId or "").strip() or None
    effective_appname = (appname or "").strip() or None

    def _log_failed_colorization(model_label: str, error_msg: str) -> None:
        # Record a failed colorization attempt with the elapsed time so far.
        log_colorization(
            result_id=None,
            model_type=model_label,
            processing_time=(time.time() - start_time),
            user_id=effective_user_id,
            ip_address=ip_address,
            status="fail",
            error=error_msg,
            appname=effective_appname
        )

    if effective_appname and effective_appname.lower() == "collage-maker":
        collage_maker_category_id = get_category_id_from_collage_maker()
        if collage_maker_category_id:
            # An explicitly supplied category id wins over the looked-up one.
            if not effective_category_id:
                effective_category_id = collage_maker_category_id
                logger.info("Using category ID from collage-maker database: %s", effective_category_id)
        else:
            logger.warning("appname is 'collage-maker' but could not fetch category ID from collage-maker database")
    elif effective_appname and effective_appname.lower() == "ai-enhancer":
        ai_enhancer_category_id = get_category_id_from_ai_enhancer()
        if ai_enhancer_category_id:
            if not effective_category_id:
                effective_category_id = ai_enhancer_category_id
                logger.info("Using category ID from AI-enhancer database: %s", effective_category_id)
        else:
            logger.warning("appname is 'ai-enhancer' but could not fetch category ID from AI-enhancer database")
    elif effective_appname:
        logger.info("appname provided: %s (not 'collage-maker' or 'ai-enhancer', skipping category lookup)", effective_appname)

    # cco_model is always bound. It used to be assigned only when CCO was
    # available, so every request on the GAN fallback path failed with an
    # UnboundLocalError that surfaced as HTTP 500.
    cco_model = "eccv16"
    if CCO_AVAILABLE:
        model_type = "cco"
        model_type_for_log = "cco-eccv16"
    else:
        model_type = "gan"
        model_type_for_log = "gan"

    if model:
        model = model.strip().lower()
        if model == "gan":
            model_type = "gan"
            model_type_for_log = "gan"
        elif model == "cco" or model.startswith("cco-"):
            if not CCO_AVAILABLE:
                error_msg = "CCO models are not available"
                log_api_call(
                    endpoint="/colorize",
                    method="POST",
                    status_code=400,
                    error=error_msg,
                    ip_address=ip_address
                )
                _log_failed_colorization("cco", error_msg)
                raise HTTPException(status_code=400, detail=error_msg)

            model_type = "cco"
            if model == "cco-siggraph17":
                cco_model = "siggraph17"
                model_type_for_log = "cco-siggraph17"
            else:
                # "cco", "cco-eccv16" and unknown "cco-*" variants use eccv16.
                cco_model = "eccv16"
                model_type_for_log = "cco-eccv16"
        # Any other value keeps the default chosen above.

    # content_type is None when the part has no Content-Type header; reject it
    # as an invalid type rather than crashing on .startswith.
    if not (file.content_type or "").startswith("image/"):
        error_msg = "Invalid file type"
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error=error_msg,
            ip_address=ip_address
        )
        _log_failed_colorization(model_type_for_log, error_msg)
        raise HTTPException(status_code=400, detail=error_msg)

    try:
        input_bytes = await file.read()
        img = Image.open(io.BytesIO(input_bytes))

        if img.mode != "RGB":
            img = img.convert("RGB")

        output_img = colorize_image(img, model_type=model_type, cco_model=cco_model)

        if output_img.mode != "RGB":
            output_img = output_img.convert("RGB")

        # Measured before saving and compression.
        processing_time = time.time() - start_time

        result_id = f"{uuid.uuid4()}.png"
        output_path = os.path.join(RESULTS_DIR, result_id)
        output_img.save(output_path, "PNG")

        logger.info("Creating compressed version targeting 2-3MB file size...")
        compressed_img, temp_compressed_path = compress_image_to_target_size(
            output_img,
            target_size_mb=2.5,
            min_size_mb=2.0,
            max_size_mb=3.0
        )

        compressed_filename = result_id.replace(".png", "_compressed.jpg")
        compressed_path = os.path.join(COMPRESSED_DIR, compressed_filename)

        # Prefer the file the compressor produced; re-encode only if it is gone.
        if os.path.exists(temp_compressed_path):
            shutil.move(temp_compressed_path, compressed_path)
        else:
            compressed_img.save(compressed_path, "JPEG", quality=75, optimize=True)

        compressed_size_mb = os.path.getsize(compressed_path) / (1024 * 1024)
        original_size_mb = os.path.getsize(output_path) / (1024 * 1024)
        logger.info("Original image size: %.2f MB, Compressed image size: %.2f MB",
                    original_size_mb, compressed_size_mb)

        base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
        result_id_clean = result_id.replace(".png", "")

        # Fixed caption returned to the client; it does not influence the model.
        caption = "colorize this image with vibrant natural colors, high quality"

        compressed_image_url = f"{base_url}/compressed/{compressed_filename}"

        if _spaces_enabled():
            try:
                source_ext = ".jpg"
                if file.filename and "." in file.filename:
                    source_ext = os.path.splitext(file.filename)[1].lower() or ".jpg"
                source_object_key = f"{DO_SPACES_BASE_FOLDER}/source/{result_id_clean}{source_ext}"
                result_object_key = f"{DO_SPACES_BASE_FOLDER}/results/{compressed_filename}"

                upload_bytes_to_spaces(input_bytes, source_object_key, file.content_type or "image/jpeg")
                compressed_image_url = upload_file_to_spaces(compressed_path, result_object_key, "image/jpeg")
                logger.info(
                    "Uploaded source and compressed result to DO Spaces: %s, %s",
                    source_object_key,
                    result_object_key
                )
            except Exception as spaces_error:
                # Best effort: the local URL remains valid when Spaces fails.
                logger.warning("Failed to upload to DO Spaces, using local compressed URL: %s", str(spaces_error))
        else:
            logger.info("DO Spaces env not fully configured or boto3 missing; using local compressed URL")

        response_data = {
            "success": True,
            "result_id": result_id_clean,
            "download_url": f"{base_url}/results/{result_id}",
            "api_download_url": f"{base_url}/download/{result_id_clean}",
            "filename": result_id,
            "caption": caption,
            "Compressed_Image_URL": compressed_image_url
        }

        log_colorization(
            result_id=result_id_clean,
            model_type=model_type_for_log,
            processing_time=processing_time,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="success",
            appname=effective_appname
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type, "model": model},
            response_data=response_data,
            user_id=effective_user_id,
            ip_address=ip_address
        )

        return response_data
    except Exception as e:
        error_msg = str(e)
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Error colorizing image: %s", error_msg)

        _log_failed_colorization(model_type_for_log, error_msg)

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}")
|
|
| |
| |
| |
@app.get("/download/{file_id}")
def download_result(
    request: Request,
    file_id: str,
    x_firebase_appcheck: str = Header(None)
):
    """Serve a stored result by id, preferring the PNG over the JPEG variant.

    Responds 404 when neither ``<file_id>.png`` nor ``<file_id>.jpg`` exists
    in RESULTS_DIR. Every call is recorded with ``log_api_call``.
    """
    verify_app_check_token(x_firebase_appcheck)

    caller = request.client
    ip_address = caller.host if caller else None
    endpoint = f"/download/{file_id}"

    # Probe the known extensions in priority order.
    located = None
    for extension, mime in ((".png", "image/png"), (".jpg", "image/jpeg")):
        candidate = os.path.join(RESULTS_DIR, f"{file_id}{extension}")
        if os.path.exists(candidate):
            located = (candidate, mime)
            break

    if located is None:
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="Result not found")

    path, media_type = located
    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=200,
        request_data={"file_id": file_id},
        ip_address=ip_address
    )

    return FileResponse(path, media_type=media_type)
|
|
| |
| |
| |
@app.get("/results/{filename}")
def get_result(request: Request, filename: str):
    """Serve a file from RESULTS_DIR by its file name.

    ``filename`` comes straight from the URL, so it is only served when it is
    a bare name (no directory component) that resolves to a regular file.
    Anything else -- including ``.``/``..``, which ``os.path.exists`` would
    accept as existing directories -- is answered with 404 instead of failing
    later inside ``FileResponse``.

    Raises:
        HTTPException: 404 when the name is not a servable file.
    """
    ip_address = request.client.host if request.client else None
    endpoint = f"/results/{filename}"

    path = os.path.join(RESULTS_DIR, filename)
    is_servable = os.path.basename(filename) == filename and os.path.isfile(path)
    if not is_servable:
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="Result not found")

    # JPEG is detected by extension; every other extension is served as PNG.
    if filename.lower().endswith((".jpg", ".jpeg")):
        media_type = "image/jpeg"
    else:
        media_type = "image/png"

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )

    return FileResponse(path, media_type=media_type)
|
|
| |
| |
| |
@app.get("/uploads/{filename}")
def get_upload(request: Request, filename: str):
    """Serve a file from UPLOAD_DIR by its file name.

    ``filename`` comes straight from the URL, so it is only served when it is
    a bare name (no directory component) that resolves to a regular file.
    Directories such as ``..`` pass ``os.path.exists`` but cannot be sent by
    ``FileResponse``; they are answered with 404 here.

    Raises:
        HTTPException: 404 when the name is not a servable file.
    """
    ip_address = request.client.host if request.client else None
    endpoint = f"/uploads/{filename}"

    path = os.path.join(UPLOAD_DIR, filename)
    is_servable = os.path.basename(filename) == filename and os.path.isfile(path)
    if not is_servable:
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=404,
            error="File not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="File not found")

    # PNG is detected by extension; every other extension is served as JPEG.
    if filename.lower().endswith(".png"):
        media_type = "image/png"
    else:
        media_type = "image/jpeg"

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )

    return FileResponse(path, media_type=media_type)
|
|
| |
| |
| |
@app.get("/compressed/{filename}")
def get_compressed(request: Request, filename: str):
    """Serve a file from COMPRESSED_DIR by its file name, always as JPEG.

    ``filename`` comes straight from the URL, so it is only served when it is
    a bare name (no directory component) that resolves to a regular file.
    Directories such as ``..`` pass ``os.path.exists`` but cannot be sent by
    ``FileResponse``; they are answered with 404 here.

    Raises:
        HTTPException: 404 when the name is not a servable file.
    """
    ip_address = request.client.host if request.client else None
    endpoint = f"/compressed/{filename}"

    path = os.path.join(COMPRESSED_DIR, filename)
    is_servable = os.path.basename(filename) == filename and os.path.isfile(path)
    if not is_servable:
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=404,
            error="Compressed file not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="Compressed file not found")

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )

    return FileResponse(path, media_type="image/jpeg")
| |
@app.get("/")
async def root():
    """Return the static service banner (version and product metadata)."""
    product_info = {
        "version": "1.0.0",
        "Product Name": "Beauty Camera - GlowCam AI Studio",
        "Released By": "LogicGo Infotech",
    }
    banner = {"success": True, "message": "Image Colorization API"}
    banner["data"] = product_info
    return banner
|
|
| |
| |
| |
|
|
|
|
| """ |
| FastAPI application for Text-Guided Image Colorization using Hugging Face Inference API |
| Uses fal-ai provider for memory-efficient inference |
| """ |
| import os |
| import io |
| import uuid |
| import logging |
| from pathlib import Path |
| from typing import Optional, Tuple |
|
|
| from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request, Body, Form |
| from fastapi.responses import FileResponse, JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.staticfiles import StaticFiles |
| import firebase_admin |
| from firebase_admin import credentials, app_check, auth as firebase_auth |
| from PIL import Image |
| import uvicorn |
| import gradio as gr |
| import httpx |
| from pydantic import BaseModel, EmailStr |
|
|
| |
| from huggingface_hub import InferenceClient |
|
|
| from app.config import settings |
| from app.database import ( |
| get_database, |
| log_api_call, |
| log_image_upload, |
| log_colorization, |
| log_media_click, |
| close_connection, |
| ) |
|
|
| |
# Root logging: INFO and above, each record stamped with time, logger and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Scratch directories under /tmp are created up front (idempotent).
for _scratch_dir in (
    "/tmp/hf_cache",
    "/tmp/matplotlib_config",
    "/tmp/colorize_uploads",
    "/tmp/colorize_results",
):
    Path(_scratch_dir).mkdir(parents=True, exist_ok=True)
del _scratch_dir

# NOTE(review): an `app` object is already created near the top of this file;
# this assignment rebinds the name, so the decorators below attach their
# routes to this new instance rather than the earlier one.
app = FastAPI(
    version="1.0.0",
    title="Text-Guided Image Colorization API",
    description="Image colorization using SDXL + ControlNet with automatic captioning",
)

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
|
| |
| |
# Firebase Admin SDK bootstrap.
# Order of attempts:
#   1. the first service-account JSON file found among `firebase_cred_paths`;
#   2. a JSON document held in the FIREBASE_CREDENTIALS environment variable;
#   3. `firebase_admin.initialize_app()` with no explicit credentials.
# Every attempt is best-effort: failures are logged and the next one is tried.
firebase_cred_paths = [
    os.getenv("FIREBASE_CREDENTIALS_PATH"),
    "/tmp/firebase-adminsdk.json",
    "/data/firebase-adminsdk.json",
    "colorize-662df-firebase-adminsdk-fbsvc-bfd21c77c6.json",
    os.path.join(os.path.dirname(__file__), "..", "colorize-662df-firebase-adminsdk-fbsvc-bfd21c77c6.json"),
]

firebase_initialized = False
for cred_path in firebase_cred_paths:
    if not cred_path:
        # FIREBASE_CREDENTIALS_PATH is unset.
        continue
    cred_path = os.path.abspath(cred_path)
    if not os.path.exists(cred_path):
        continue
    try:
        cred = credentials.Certificate(cred_path)
        firebase_admin.initialize_app(cred)
    except Exception as e:
        logger.warning("Failed to initialize Firebase from %s: %s", cred_path, str(e))
        continue
    logger.info("Firebase Admin SDK initialized from: %s", cred_path)
    firebase_initialized = True
    break

if not firebase_initialized:
    firebase_json = os.getenv("FIREBASE_CREDENTIALS")
    if firebase_json:
        try:
            import json
            firebase_dict = json.loads(firebase_json)
            cred = credentials.Certificate(firebase_dict)
            firebase_admin.initialize_app(cred)
            logger.info("Firebase Admin SDK initialized from environment variable")
            firebase_initialized = True
        except Exception as e:
            logger.warning("Failed to initialize Firebase from environment: %s", str(e))

if not firebase_initialized:
    logger.warning("Firebase credentials file not found. App Check will be disabled.")
    # NOTE(review): if this fallback succeeds, `firebase_admin._apps` becomes
    # non-empty, so code that keys off `_apps` will still treat Firebase as
    # initialized -- confirm that matches the warning above.
    try:
        firebase_admin.initialize_app()
    except Exception as e:
        # Still best-effort, but no longer silent; a bare `except:` here also
        # swallowed KeyboardInterrupt/SystemExit.
        logger.warning("Default Firebase initialization failed: %s", str(e))
|
|
| |
# Storage locations for uploaded inputs and colorized outputs.
UPLOAD_DIR = Path("/tmp") / "colorize_uploads"
RESULT_DIR = Path("/tmp") / "colorize_results"

# Both directories are exposed as static mounts.
# NOTE(review): these mounts are registered before any `/results/{filename}`
# route declared later in the file, so they presumably take precedence for
# those URLs -- confirm against the router's matching order.
app.mount("/results", StaticFiles(directory=os.fspath(RESULT_DIR)), name="results")
app.mount("/uploads", StaticFiles(directory=os.fspath(UPLOAD_DIR)), name="uploads")

# Module-level state filled in by the startup hook.
inference_client = None
model_load_error: Optional[str] = None
|
|
| |
|
|
def apply_color(image: Image.Image, color_map: Image.Image) -> Image.Image:
    """Combine the L channel of ``image`` with the a/b channels of ``color_map``.

    Both inputs are converted to LAB, the channels are recombined with
    ``Image.merge`` and the result is returned converted to RGB.
    """
    lightness = image.convert('LAB').split()[0]
    chroma_a, chroma_b = color_map.convert('LAB').split()[1:]
    recombined = Image.merge('LAB', (lightness, chroma_a, chroma_b))
    return recombined.convert('RGB')
|
|
|
|
def remove_unlikely_words(prompt: str) -> str:
    """Strip a fixed list of era/monochrome/camera phrases from ``prompt``.

    Phrases are removed one after another with ``str.replace`` in list order:
    generated year phrases for 1900-1999 first, then the hand-written list.
    Later phrases therefore see the text left behind by earlier removals.
    """
    years = [str(year) for year in range(1900, 2000)]
    # Digit-by-digit spellings of each year, e.g. "1 9 0 0".
    spaced = [" ".join(year) for year in years]

    year_phrases = (
        [year + "s" for year in years]
        + years
        + ["year " + year for year in years]
        + ["circa " + year for year in years]
        + [digits + " s" for digits in spaced]
        + spaced
        + ["year " + digits for digits in spaced]
        + ["circa " + digits for digits in spaced]
    )

    manual_phrases = [
        "black and white,", "black and white", "black & white,", "black & white", "circa",
        "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
        "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
        "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
        "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
        "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
        "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
        "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
        "black-and-white photo,", "black-and-white photo", "black - and - white photography",
        "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
        "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
        "black - and - white photograph,", "black - and - white photograph", "black on white,",
        "black on white", "black-and-white", "historical image,", "historical picture,",
        "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
        "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
        "historical photo", "historical setting,",
        "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
        "taken in", "shot on leica", "shot on leica sl2", "sl2",
        "taken with a leica camera", "leica sl2", "leica", "setting",
        "overcast day", "overcast weather", "slight overcast", "overcast",
        "picture taken in", "photo taken in",
        ", photo", ", photo", ", photo", ", photo", ", photograph",
        ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
    ]

    cleaned = prompt
    for phrase in year_phrases + manual_phrases:
        cleaned = cleaned.replace(phrase, "")
    return cleaned
|
|
|
|
| |
|
|
@app.on_event("startup")
async def startup_event():
    """Startup hook: touch the database, then build the Inference API client.

    Neither step aborts startup. A database failure is only logged; a client
    failure is logged and stored in ``model_load_error`` while
    ``inference_client`` keeps its previous value.
    """
    global inference_client, model_load_error

    try:
        if get_database() is not None:
            logger.info("✅ MongoDB initialized successfully!")
    except Exception as db_error:
        logger.warning("⚠️ MongoDB initialization failed: %s", str(db_error))

    try:
        logger.info("🔄 Initializing Hugging Face Inference API client...")

        # The environment variable wins over the settings object.
        token = os.getenv("HF_TOKEN") or settings.HF_TOKEN
        if not token:
            raise ValueError("HF_TOKEN environment variable is required for Inference API")

        inference_client = InferenceClient(provider="fal-ai", api_key=token)

        logger.info("✅ Inference API client initialized successfully!")
        model_load_error = None
    except Exception as init_error:
        error_msg = str(init_error)
        logger.error(f"❌ Failed to initialize Inference API client: {error_msg}")
        model_load_error = error_msg
| |
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Shutdown hook: drop the inference client and close the database connection."""
    global inference_client
    if inference_client:
        # Release the reference held in module state.
        inference_client = None
    close_connection()
    logger.info("Application shutdown")
|
|
|
|
| |
|
|
class RegisterRequest(BaseModel):
    # Request body for POST /auth/register.
    # `display_name` may be omitted; it is forwarded as-is to user creation.
    email: EmailStr
    password: str
    display_name: Optional[str] = None
|
|
class LoginRequest(BaseModel):
    # Request body for POST /auth/login.
    email: EmailStr
    password: str
|
|
class TokenResponse(BaseModel):
    # Response body shared by the register and login endpoints.
    # `expires_in` is a number of seconds; `user` is a plain dict of profile fields.
    id_token: str
    refresh_token: Optional[str] = None
    expires_in: int
    token_type: str = "Bearer"
    user: dict
|
|
| |
|
|
| def _extract_bearer_token(authorization_header: str | None) -> str | None: |
| if not authorization_header: |
| return None |
| parts = authorization_header.split(" ", 1) |
| if len(parts) == 2 and parts[0].lower() == "bearer": |
| return parts[1].strip() |
| return None |
|
|
|
|
async def verify_request(request: Request):
    """Dependency that authenticates a request against Firebase.

    Checks, in order:
      1. No Firebase app registered, or DISABLE_AUTH=true -> allow.
      2. ``X-Firebase-AppCheck`` header: allow if it verifies; if it does not
         and ``settings.ENABLE_APP_CHECK`` is set, reject with 401.
      3. ``Authorization: Bearer`` ID token: if it verifies, store the decoded
         claims on ``request.state.user`` and allow.
      4. Otherwise reject with 401 when ``settings.ENABLE_APP_CHECK`` is set,
         else allow.

    Returns:
        True whenever the request is allowed.

    Raises:
        HTTPException: 401 for a missing or invalid App Check token while
            App Check enforcement is enabled.
    """
    if not firebase_admin._apps:
        return True
    if os.getenv("DISABLE_AUTH", "false").lower() == "true":
        return True

    appcheck_header = request.headers.get("X-Firebase-AppCheck")
    if appcheck_header:
        try:
            claims = app_check.verify_token(appcheck_header)
            logger.info("App Check token verified for: %s", claims.get("app_id"))
            return True
        except Exception as appcheck_error:
            logger.warning("App Check token verification failed: %s", str(appcheck_error))
            if settings.ENABLE_APP_CHECK:
                raise HTTPException(status_code=401, detail="Invalid App Check token")

    id_token = _extract_bearer_token(request.headers.get("Authorization"))
    if id_token:
        try:
            decoded_claims = firebase_auth.verify_id_token(id_token)
            request.state.user = decoded_claims
            logger.info("Firebase Auth id_token verified for uid: %s", decoded_claims.get("uid"))
            return True
        except Exception as token_error:
            # A bad bearer token is not fatal by itself; enforcement is decided below.
            logger.warning("Auth token verification failed: %s", str(token_error))

    if not settings.ENABLE_APP_CHECK:
        return True

    detail = "Invalid App Check token" if appcheck_header else "Missing App Check token"
    raise HTTPException(status_code=401, detail=detail)
|
|
|
|
| def _resolve_user_id(request: Request, supplied_user_id: Optional[str]) -> Optional[str]: |
| """Return supplied user_id if provided and not empty, otherwise None (will auto-generate in log_media_click).""" |
| if supplied_user_id and supplied_user_id.strip(): |
| return supplied_user_id.strip() |
| return None |
|
|
|
|
| |
|
|
@app.post("/auth/register", response_model=TokenResponse)
async def register_user(user_data: RegisterRequest):
    """Create a Firebase user and return a token payload for it.

    The ``id_token`` field of the response carries the value produced by
    ``firebase_auth.create_custom_token`` for the new uid.

    Raises:
        HTTPException: 503 when no Firebase app is registered; 400 for an
            already-registered email or a ``ValueError`` from the SDK; 500
            for any other failure.
    """
    if not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")

    try:
        new_user = firebase_auth.create_user(
            email=user_data.email,
            password=user_data.password,
            display_name=user_data.display_name,
            email_verified=False,
        )
        token_bytes = firebase_auth.create_custom_token(new_user.uid)

        logger.info("User registered: %s (uid: %s)", user_data.email, new_user.uid)

        token_text = token_bytes.decode('utf-8')
        profile = {
            "uid": new_user.uid,
            "email": new_user.email,
            "display_name": new_user.display_name,
            "email_verified": new_user.email_verified,
        }
        return TokenResponse(
            id_token=token_text,
            token_type="Bearer",
            expires_in=3600,
            user=profile,
        )
    except firebase_auth.EmailAlreadyExistsError:
        raise HTTPException(status_code=400, detail="Email already registered")
    except ValueError as bad_input:
        raise HTTPException(status_code=400, detail=f"Invalid input: {str(bad_input)}")
    except Exception as failure:
        logger.error("Registration error: %s", str(failure))
        raise HTTPException(status_code=500, detail=f"Registration failed: {str(failure)}")
|
|
|
|
@app.post("/auth/login", response_model=TokenResponse)
async def login_user(credentials: LoginRequest):
    """Log in with email and password via the Firebase Identity Toolkit REST API.

    The password is verified by Firebase's ``accounts:signInWithPassword``
    endpoint, which requires a Web API key (``FIREBASE_API_KEY`` env var or
    ``settings.FIREBASE_API_KEY``).

    Without that key the request is refused: the Admin SDK cannot check a
    password, and issuing a token from the email alone would let anyone sign
    in as any registered address.

    Raises:
        HTTPException: 503 when Firebase or the API key is not configured;
            401 when Firebase rejects the credentials; 500 for transport or
            unexpected errors.
    """
    if not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")

    firebase_api_key = os.getenv("FIREBASE_API_KEY") or settings.FIREBASE_API_KEY
    if not firebase_api_key:
        # SECURITY: never mint a token without verifying the password.
        logger.error("Login rejected: FIREBASE_API_KEY is not configured, password cannot be verified")
        raise HTTPException(status_code=503, detail="Firebase API key not configured")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={firebase_api_key}",
                json={
                    "email": credentials.email,
                    "password": credentials.password,
                    "returnSecureToken": True
                }
            )

        if response.status_code != 200:
            error_data = response.json()
            error_msg = error_data.get("error", {}).get("message", "Authentication failed")
            raise HTTPException(status_code=401, detail=error_msg)

        data = response.json()
        logger.info("User login successful: %s", credentials.email)

        # Fetch the profile so the response matches the register endpoint's shape.
        user_record = firebase_auth.get_user(data["localId"])

        return TokenResponse(
            id_token=data["idToken"],
            refresh_token=data.get("refreshToken"),
            expires_in=int(data.get("expiresIn", 3600)),
            token_type="Bearer",
            user={
                "uid": user_record.uid,
                "email": user_record.email,
                "display_name": user_record.display_name,
                "email_verified": user_record.email_verified
            }
        )
    except HTTPException:
        # Let the deliberate 401 above through; the generic handler below
        # would otherwise turn it into a 500.
        raise
    except httpx.HTTPError as e:
        logger.error("HTTP error during login: %s", str(e))
        raise HTTPException(status_code=500, detail="Authentication service unavailable")
    except Exception as e:
        logger.error("Login error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Login failed: {str(e)}")
|
|
|
|
@app.get("/auth/me")
async def get_current_user(request: Request, verified: bool = Depends(verify_request)):
    """Return the profile of the user whose claims are on ``request.state.user``.

    Raises:
        HTTPException: 503 when no Firebase app is registered; 401 when the
            request carries no decoded user; 404 when the user lookup fails.
    """
    if not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")

    has_claims = hasattr(request, 'state') and hasattr(request.state, 'user')
    if not has_claims:
        raise HTTPException(status_code=401, detail="Not authenticated")

    uid = request.state.user.get("uid")
    try:
        record = firebase_auth.get_user(uid)
        return {
            "uid": record.uid,
            "email": record.email,
            "display_name": record.display_name,
            "email_verified": record.email_verified,
            "created_at": record.user_metadata.creation_timestamp,
        }
    except Exception as lookup_error:
        logger.error("Error getting user: %s", str(lookup_error))
        raise HTTPException(status_code=404, detail="User not found")
|
|
|
|
@app.post("/auth/refresh")
async def refresh_token(refresh_token: str = Body(..., embed=True)):
    """Exchange a Firebase refresh token for a new ID token.

    Args:
        refresh_token: Read from the JSON body field ``refresh_token``.

    Returns:
        A dict with ``id_token``, ``refresh_token``, ``expires_in`` and
        ``token_type``.

    Raises:
        HTTPException: 503 when no API key is configured; 401 when the token
            service rejects the refresh token; 500 for transport or
            unexpected errors.
    """
    firebase_api_key = os.getenv("FIREBASE_API_KEY") or settings.FIREBASE_API_KEY
    if not firebase_api_key:
        raise HTTPException(status_code=503, detail="Firebase API key not configured")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://securetoken.googleapis.com/v1/token?key={firebase_api_key}",
                json={
                    "grant_type": "refresh_token",
                    "refresh_token": refresh_token
                }
            )

        if response.status_code != 200:
            error_data = response.json()
            error_msg = error_data.get("error", {}).get("message", "Token refresh failed")
            raise HTTPException(status_code=401, detail=error_msg)

        data = response.json()
        return {
            "id_token": data["id_token"],
            "refresh_token": data.get("refresh_token"),
            "expires_in": int(data.get("expires_in", 3600)),
            "token_type": "Bearer"
        }
    except HTTPException:
        # Let the deliberate 401 above through; the generic handler below
        # would otherwise turn it into a 500.
        raise
    except httpx.HTTPError as e:
        logger.error("HTTP error during token refresh: %s", str(e))
        raise HTTPException(status_code=500, detail="Token refresh service unavailable")
    except Exception as e:
        logger.error("Token refresh error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Token refresh failed: {str(e)}")
|
|
|
|
| |
|
|
@app.get("/api")
async def api_info(request: Request):
    """Return a static directory of the API's endpoints and log the call."""
    auth_routes = {
        "register": "/auth/register",
        "login": "/auth/login",
        "me": "/auth/me",
        "refresh": "/auth/refresh",
    }
    endpoint_map = {
        "health": "/health",
        "upload": "/upload",
        "colorize": "/colorize",
        "download": "/download/{file_id}",
        "results": "/results/{filename}",
        "uploads": "/uploads/{filename}",
        "auth": auth_routes,
        "gradio": "/",
    }
    response_data = {
        "app": "Text-Guided Image Colorization API",
        "version": "1.0.0",
        "endpoints": endpoint_map,
    }

    # The uid is attached to the log entry only when decoded claims are present.
    caller_uid = None
    if hasattr(request, 'state') and hasattr(request.state, 'user'):
        caller_uid = request.state.user.get("uid")

    client_ip = request.client.host if request.client else None
    log_api_call(
        endpoint="/api",
        method="GET",
        status_code=200,
        response_data=response_data,
        user_id=caller_uid,
        ip_address=client_ip
    )

    return response_data
|
|
|
|
@app.get("/health")
async def health_check(request: Request):
    """Report service status and whether the inference client exists.

    ``model_error`` is included only when a load error has been recorded.
    """
    client_ready = inference_client is not None
    payload = {
        "status": "healthy",
        "model_loaded": client_ready,
        "model_type": "hf_inference_api",
        "provider": "fal-ai",
    }
    if model_load_error:
        payload["model_error"] = model_load_error

    client_ip = request.client.host if request.client else None
    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=payload,
        ip_address=client_ip
    )

    return payload
|
|
|
|
def colorize_image_sdxl(
    image: Image.Image,
    positive_prompt: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    seed: int = 123,
    num_inference_steps: int = 8
) -> Tuple[Image.Image, str]:
    """Colorize ``image`` through the Inference API's image-to-image call.

    The input is converted to RGB, resized to 512x512 and sent as PNG bytes.
    The result is resized back to the input's original size.

    Args:
        image: PIL image to colorize.
        positive_prompt: Prompt text; a built-in default is used when empty.
        negative_prompt: When given, appended to the prompt as ``. Avoid: ...``.
        seed: Accepted but not forwarded to the API call.
        num_inference_steps: Accepted but not forwarded to the API call.

    Returns:
        ``(colorized_image, caption)`` where the caption is the first 100
        characters of the final prompt.

    Raises:
        RuntimeError: When the client is not initialized or the call fails.
    """
    if inference_client is None:
        raise RuntimeError("Inference API client not initialized")

    target_size = image.size
    resized_rgb = image.convert("RGB").resize((512, 512))

    png_buffer = io.BytesIO()
    resized_rgb.save(png_buffer, format="PNG")
    payload = png_buffer.getvalue()

    final_prompt = positive_prompt or "colorize this image with vibrant natural colors, high quality"
    if negative_prompt:
        final_prompt = f"{final_prompt}. Avoid: {negative_prompt}"

    model_name = settings.INFERENCE_MODEL
    logger.info(f"Calling Inference API with model {model_name}, prompt: {final_prompt}")
    try:
        api_output = inference_client.image_to_image(
            payload,
            prompt=final_prompt,
            model=model_name,
        )

        # The client may hand back raw bytes instead of a PIL image.
        if not isinstance(api_output, Image.Image):
            api_output = Image.open(io.BytesIO(api_output))
        colorized = api_output.resize(target_size)

        return colorized, final_prompt[:100]
    except Exception as e:
        logger.error(f"Inference API error: {e}")
        raise RuntimeError(f"Failed to colorize image: {str(e)}")
|
|
|
|
@app.post("/upload")
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
    verified: bool = Depends(verify_request)
):
    """Store an uploaded image under a fresh UUID name and return its URL.

    The stored file's extension is taken from the client file name only when
    it is a known image extension; anything else falls back to ``jpg``. The
    client controls both the file name and the declared content type, so an
    unchecked extension would let arbitrary file types (e.g. ``.html``) be
    written into the publicly served upload directory.

    ``category_id`` / ``categoryId`` are accepted for form compatibility but
    are not used by this handler.

    Raises:
        HTTPException: 400 when the declared content type is not ``image/*``;
            500 when storing the file fails.
    """
    effective_user_id = _resolve_user_id(request, user_id)
    ip_address = request.client.host if request.client else None

    if not file.content_type or not file.content_type.startswith("image/"):
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=400,
            error="File must be an image",
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        allowed_extensions = {
            "jpg", "jpeg", "png", "gif", "bmp", "webp",
            "tif", "tiff", "heic", "heif", "avif",
        }
        _, dot, suffix = (file.filename or "").rpartition(".")
        file_extension = suffix if dot and suffix.lower() in allowed_extensions else "jpg"

        # The UUID alone is the public image id; the extension is only part of the file name.
        image_stem = str(uuid.uuid4())
        image_id = f"{image_stem}.{file_extension}"
        file_path = UPLOAD_DIR / image_id

        img_bytes = await file.read()
        file_size = len(img_bytes)
        with open(file_path, "wb") as f:
            f.write(img_bytes)

        logger.info("Image uploaded: %s", image_id)

        base_url = os.getenv("BASE_URL", settings.BASE_URL)
        if not base_url or base_url == "http://localhost:8000":
            # Unset or still the local default: fall back to the hosted Space URL.
            base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"

        response_data = {
            "success": True,
            "image_id": image_stem,
            "image_url": f"{base_url}/uploads/{image_id}",
            "filename": image_id
        }

        log_image_upload(
            image_id=image_stem,
            filename=file.filename or image_id,
            file_size=file_size,
            content_type=file.content_type or "image/jpeg",
            user_id=effective_user_id,
            ip_address=ip_address
        )

        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type},
            response_data=response_data,
            user_id=effective_user_id,
            ip_address=ip_address
        )

        return JSONResponse(response_data)
    except Exception as e:
        error_msg = str(e)
        logger.error("Error uploading image: %s", error_msg)
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=500, detail=f"Error uploading image: {error_msg}")
|
|
|
|
@app.post("/colorize")
async def colorize_api(
    request: Request,
    file: UploadFile = File(...),
    positive_prompt: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    seed: int = 123,
    num_inference_steps: int = 8,
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
    verified: bool = Depends(verify_request)
):
    """Colorize an uploaded image and return URLs for the stored PNG result.

    The image is decoded, passed to ``colorize_image_sdxl`` and the output is
    saved as ``<uuid>.png`` in RESULT_DIR. The inference call is synchronous,
    so it runs in a worker thread to keep the event loop free for other
    requests while it is in flight.

    ``category_id`` / ``categoryId`` are accepted for form compatibility but
    are not used by this handler.

    Raises:
        HTTPException: 503 when the inference client is not initialized; 400
            when the declared content type is not ``image/*``; 500 when
            decoding, inference or saving fails.
    """
    import asyncio
    import time
    start_time = time.time()

    effective_user_id = _resolve_user_id(request, user_id)
    ip_address = request.client.host if request.client else None

    if inference_client is None:
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=503,
            error="Inference API client not initialized",
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=503, detail="Inference API client not initialized")

    if not file.content_type or not file.content_type.startswith("image/"):
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error="File must be an image",
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        img_bytes = await file.read()
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        logger.info("Colorizing image with SDXL + ControlNet...")
        # Blocking network call: run it off the event loop.
        colorized, caption = await asyncio.to_thread(
            colorize_image_sdxl,
            image,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            num_inference_steps=num_inference_steps
        )

        processing_time = time.time() - start_time

        result_id = str(uuid.uuid4())
        output_filename = f"{result_id}.png"
        output_path = RESULT_DIR / output_filename
        colorized.save(output_path, "PNG")

        logger.info("Colorized image saved: %s", output_filename)

        base_url = os.getenv("BASE_URL", settings.BASE_URL)
        if not base_url or base_url == "http://localhost:8000":
            # Unset or still the local default: fall back to the hosted Space URL.
            base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"

        response_data = {
            "success": True,
            "result_id": result_id,
            "download_url": f"{base_url}/results/{output_filename}",
            "api_download_url": f"{base_url}/download/{result_id}",
            "filename": output_filename,
            "caption": caption
        }

        log_colorization(
            result_id=result_id,
            prompt=positive_prompt,
            model_type="sdxl",
            processing_time=processing_time,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="success"
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={
                "filename": file.filename,
                "positive_prompt": positive_prompt,
                "negative_prompt": negative_prompt,
                "seed": seed,
                "num_inference_steps": num_inference_steps
            },
            response_data=response_data,
            user_id=effective_user_id,
            ip_address=ip_address
        )

        return JSONResponse(response_data)
    except Exception as e:
        error_msg = str(e)
        logger.error("Error colorizing image: %s", error_msg)

        # Record how long the failed attempt ran, as the success path does.
        log_colorization(
            result_id=None,
            prompt=positive_prompt,
            model_type="sdxl",
            processing_time=time.time() - start_time,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}")
|
|
|
|
@app.get("/download/{file_id}")
def download_result(
    request: Request,
    file_id: str,
    verified: bool = Depends(verify_request)
):
    """Serve ``<file_id>.png`` from RESULT_DIR, logging the call either way.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    caller_uid = None
    if hasattr(request, 'state') and hasattr(request.state, 'user'):
        caller_uid = request.state.user.get("uid")

    client_ip = request.client.host if request.client else None
    endpoint = f"/download/{file_id}"
    png_path = RESULT_DIR / f"{file_id}.png"

    if png_path.exists():
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=200,
            request_data={"file_id": file_id},
            user_id=caller_uid,
            ip_address=client_ip
        )
        return FileResponse(png_path, media_type="image/png")

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=404,
        error="Result not found",
        user_id=caller_uid,
        ip_address=client_ip
    )
    raise HTTPException(status_code=404, detail="Result not found")
|
|
|
|
@app.get("/results/{filename}")
def get_result(request: Request, filename: str):
    """Serve a colorized image from ``RESULT_DIR`` without authentication.

    Args:
        request: Incoming request; only used to obtain the client IP.
        filename: File name taken from the URL path.

    Returns:
        A ``FileResponse`` with media type ``image/png``.

    Raises:
        HTTPException: 404 when ``filename`` is not a plain file name or does
            not refer to a regular file inside ``RESULT_DIR``.
    """
    ip_address = request.client.host if request.client else None
    endpoint = f"/results/{filename}"

    path = RESULT_DIR / filename

    # ``filename`` is caller-controlled. Names such as "." or ".." resolve to
    # existing directories, which pass ``exists()`` but make FileResponse fail
    # with a server error, so require a bare name that is a regular file.
    if os.path.basename(filename) != filename or not path.is_file():
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=ip_address,
        )
        raise HTTPException(status_code=404, detail="Result not found")

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address,
    )

    return FileResponse(path, media_type="image/png")
|
|
|
|
| |
|
|
def gradio_colorize(image, positive_prompt=None, negative_prompt=None, seed=123):
    """Gradio callback: colorize ``image`` and return ``(image, caption)``.

    Returns ``(None, "")`` when no image was supplied. Any failure, including
    a missing inference client, is reported as ``(None, <message>)`` rather
    than raised, so the message ends up in the caption output.
    """
    if image is None:
        return None, ""

    try:
        if inference_client is not None:
            colorized, caption = colorize_image_sdxl(
                image,
                positive_prompt=positive_prompt,
                negative_prompt=negative_prompt,
                seed=seed,
            )
            return colorized, caption
        return None, "Inference API client not initialized"
    except Exception as exc:
        logger.error("Gradio colorization error: %s", str(exc))
        return None, str(exc)
|
|
|
|
# Texts shown at the top of the Gradio page.
title = "🎨 Text-Guided Image Colorization"
description = "Upload a grayscale image and generate a color version using Hugging Face Inference API (fal-ai provider)."

# Component order matches the positional parameters / return tuple of
# ``gradio_colorize``: image, positive prompt, negative prompt, seed.
iface = gr.Interface(
    fn=gradio_colorize,
    title=title,
    description=description,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Positive Prompt", placeholder="Enter details to enhance the caption"),
        gr.Textbox(label="Negative Prompt", value=settings.NEGATIVE_PROMPT),
        gr.Slider(minimum=0, maximum=1000, value=123, label="Seed"),
    ],
    outputs=[
        gr.Image(type="pil", label="Colorized Image"),
        gr.Textbox(label="Caption"),
    ],
)

# Serve the Gradio UI from the root path of the FastAPI application.
app = gr.mount_gradio_app(app, iface, path="/")
|
|
|
|
| |
|
|
# Script entry point: listen on all interfaces, port taken from $PORT
# (7860 when the variable is unset).
if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
    )
|
|
|
|
| |
| |
| |
|
|
| """ |
| FastAPI application for FastAI GAN Image Colorization |
| with Firebase Authentication and Gradio UI |
| """ |
| import os |
| |
# Process-wide settings, applied ahead of the third-party imports that follow
# (presumably so those libraries read them at import time): one OpenMP thread,
# and every Hugging Face / XDG / matplotlib cache redirected under /tmp.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "HF_HOME": "/tmp/hf_cache",
        "TRANSFORMERS_CACHE": "/tmp/hf_cache",
        "HF_HUB_CACHE": "/tmp/hf_cache",
        "HUGGINGFACE_HUB_CACHE": "/tmp/hf_cache",
        "XDG_CACHE_HOME": "/tmp/hf_cache",
        "MPLCONFIGDIR": "/tmp/matplotlib_config",
    }
)
|
|
| import io |
| import uuid |
| import logging |
| from pathlib import Path |
| from typing import Optional |
|
|
| from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request, Form |
| from fastapi.responses import FileResponse, JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.staticfiles import StaticFiles |
| import firebase_admin |
| from firebase_admin import credentials, app_check, auth as firebase_auth |
| from PIL import Image |
| import torch |
| import uvicorn |
| import gradio as gr |
| import numpy as np |
| import cv2 |
|
|
| |
| from fastai.vision.all import * |
| from huggingface_hub import from_pretrained_fastai |
|
|
| from app.config import settings |
| from app.pytorch_colorizer import PyTorchColorizer |
| from app.database import ( |
| get_database, |
| log_api_call, |
| log_image_upload, |
| log_colorization, |
| log_media_click, |
| close_connection, |
| ) |
|
|
| |
# Root logging configuration plus this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Create the cache and working directories up front; already-existing
# directories are left untouched.
for _scratch_dir in (
    "/tmp/hf_cache",
    "/tmp/matplotlib_config",
    "/tmp/colorize_uploads",
    "/tmp/colorize_results",
):
    Path(_scratch_dir).mkdir(parents=True, exist_ok=True)

# FastAPI application object.
app = FastAPI(
    title="FastAI Image Colorizer API",
    description="Image colorization using FastAI GAN model with Firebase authentication",
    version="1.0.0",
)

# CORS: every origin, method and header is allowed, credentials included.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
| |
# Firebase Admin initialisation.
# 1. Prefer the service-account file named by FIREBASE_CREDENTIALS_PATH.
# 2. If that file is missing or unusable, make one best-effort attempt with
#    the SDK's default credentials.
# A failure in either step is logged and never aborts start-up.
firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH", "/tmp/firebase-adminsdk.json")
if os.path.exists(firebase_cred_path):
    try:
        cred = credentials.Certificate(firebase_cred_path)
        firebase_admin.initialize_app(cred)
        logger.info("Firebase Admin SDK initialized")
    except Exception as e:
        logger.warning("Failed to initialize Firebase: %s", str(e))
else:
    logger.warning("Firebase credentials file not found. App Check will be disabled.")

# ``_apps`` is empty only when no app has been registered so far, i.e. the
# attempt above did not happen or did not succeed.
if not firebase_admin._apps:
    try:
        firebase_admin.initialize_app()
    except Exception as e:
        # Still best-effort, but record the reason instead of discarding it;
        # catching Exception (not a bare except) lets SystemExit and
        # KeyboardInterrupt propagate.
        logger.warning("Firebase default initialization failed: %s", str(e))
|
|
| |
# Working directories for uploaded inputs and generated outputs.
_TMP_ROOT = Path("/tmp")
UPLOAD_DIR = _TMP_ROOT / "colorize_uploads"
RESULT_DIR = _TMP_ROOT / "colorize_results"

# Expose both directories as static file trees.
app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")

# Model state, filled in by the startup handler.
learn = pytorch_colorizer = None  # FastAI learner / PyTorch colorizer
model_load_error: Optional[str] = None  # message of the last load failure
model_type: str = "none"  # "fastai", "pytorch" or "none"
|
|
@app.on_event("startup")
async def startup_event():
    """Initialise MongoDB and load a colorization model at start-up.

    The FastAI learner is tried first; if that raises, the PyTorch colorizer
    is tried with the same model id. The outcome is published through the
    module globals ``learn``, ``pytorch_colorizer``, ``model_type`` and
    ``model_load_error``. No failure here is propagated.
    """
    global learn, pytorch_colorizer, model_load_error, model_type

    # Database first; a failure is only logged.
    try:
        if get_database() is not None:
            logger.info("✅ MongoDB initialized successfully!")
    except Exception as db_exc:
        logger.warning("⚠️ MongoDB initialization failed: %s", str(db_exc))

    model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")

    # Preferred backend: FastAI.
    try:
        logger.info("🔄 Attempting to load FastAI GAN Colorization Model: %s", model_id)
        learn = from_pretrained_fastai(model_id)
    except Exception as fastai_exc:
        logger.warning(
            "⚠️ FastAI model loading failed: %s. Trying PyTorch fallback...",
            str(fastai_exc),
        )
    else:
        logger.info("✅ FastAI model loaded successfully!")
        model_type = "fastai"
        model_load_error = None
        return

    # Second choice: plain PyTorch generator weights.
    try:
        logger.info("🔄 Attempting to load PyTorch GAN Colorization Model: %s", model_id)
        pytorch_colorizer = PyTorchColorizer(model_id=model_id, model_filename="generator.pt")
    except Exception as torch_exc:
        failure = str(torch_exc)
        logger.error("❌ Failed to load both FastAI and PyTorch models: %s", failure)
        model_load_error = failure
        model_type = "none"
    else:
        logger.info("✅ PyTorch model loaded successfully!")
        model_type = "pytorch"
        model_load_error = None
| |
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Release the loaded models and close the database connection.

    The model globals are rebound to ``None`` instead of being deleted:
    ``del`` on a ``global`` name removes the binding itself, after which any
    read of ``learn`` / ``pytorch_colorizer`` (a late request, or a second
    call of this handler) would raise ``NameError``. Rebinding drops the
    references just the same and keeps the module in its "no model" state.
    """
    global learn, pytorch_colorizer
    learn = None
    pytorch_colorizer = None
    close_connection()
    logger.info("Application shutdown")
|
|
| def _extract_bearer_token(authorization_header: str | None) -> str | None: |
| if not authorization_header: |
| return None |
| parts = authorization_header.split(" ", 1) |
| if len(parts) == 2 and parts[0].lower() == "bearer": |
| return parts[1].strip() |
| return None |
|
|
async def verify_request(request: Request):
    """FastAPI dependency that authenticates a request through Firebase.

    Checks, in order:

    1. No Firebase app registered, or ``DISABLE_AUTH=true``: accept.
    2. ``Authorization: Bearer <id_token>`` that verifies: store the decoded
       token in ``request.state.user`` and accept. A token that fails
       verification is only logged; processing continues with step 3.
    3. ``settings.ENABLE_APP_CHECK`` set: require a valid
       ``X-Firebase-AppCheck`` header, otherwise raise 401.
    4. App Check not enabled: accept.

    NOTE(review): because of step 4, a request with no (or an invalid) bearer
    token is accepted whenever App Check is disabled — confirm this is the
    intended policy.

    Returns:
        ``True`` whenever the request is accepted.

    Raises:
        HTTPException: 401 for a missing or invalid App Check token.
    """
    auth_disabled = os.getenv("DISABLE_AUTH", "false").lower() == "true"
    if auth_disabled or not firebase_admin._apps:
        return True

    # Firebase Auth ID token.
    id_token = _extract_bearer_token(request.headers.get("Authorization"))
    if id_token:
        try:
            claims = firebase_auth.verify_id_token(id_token)
            request.state.user = claims
            logger.info("Firebase Auth id_token verified for uid: %s", claims.get("uid"))
            return True
        except Exception as exc:
            logger.warning("Auth token verification failed: %s", str(exc))

    # Nothing more to check unless App Check is switched on.
    if not settings.ENABLE_APP_CHECK:
        return True

    app_check_token = request.headers.get("X-Firebase-AppCheck")
    if not app_check_token:
        raise HTTPException(status_code=401, detail="Missing App Check token")

    try:
        app_check_claims = app_check.verify_token(app_check_token)
        logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
        return True
    except Exception as exc:
        logger.warning("App Check token verification failed: %s", str(exc))
        raise HTTPException(status_code=401, detail="Invalid App Check token")
|
|
|
|
| def _resolve_user_id(request: Request, supplied_user_id: Optional[str]) -> Optional[str]: |
| """Return supplied user_id if provided and not empty, otherwise None (will auto-generate in log_media_click).""" |
| if supplied_user_id and supplied_user_id.strip(): |
| return supplied_user_id.strip() |
| return None |
|
|
@app.get("/api")
async def api_info(request: Request):
    """Return static information about the API and log the call."""
    # The uid is available only if something stored a decoded user on the
    # request state.
    user_id = None
    if hasattr(request, 'state') and hasattr(request.state, 'user'):
        user_id = request.state.user.get("uid")

    client = request.client
    response_data = {
        "app": "FastAI Image Colorizer API",
        "version": "1.0.0",
        "health": "/health",
        "colorize": "/colorize",
        "gradio": "/",
    }

    log_api_call(
        endpoint="/api",
        method="GET",
        status_code=200,
        response_data=response_data,
        user_id=user_id,
        ip_address=client.host if client else None,
    )

    return response_data
|
|
@app.get("/health")
async def health_check(request: Request):
    """Report service status and which colorization backend is active.

    ``status`` is always ``"healthy"``. ``model_loaded`` is true when either
    the FastAI learner or the PyTorch colorizer is present; ``model_error``
    is included only when a load error was recorded.
    """
    has_model = learn is not None or pytorch_colorizer is not None

    if model_load_error:
        message = "Model failed to load. Using fallback colorization method."
    elif has_model:
        message = f"Model loaded successfully ({model_type})"
    else:
        message = "No model loaded. Using fallback colorization method."

    response = {
        "status": "healthy",
        "model_loaded": has_model,
        "model_type": model_type,
        "model_id": os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model"),
        "using_fallback": not has_model,
    }
    if model_load_error:
        response["model_error"] = model_load_error
    response["message"] = message

    client = request.client
    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=response,
        ip_address=client.host if client else None,
    )

    return response
|
|
def simple_colorize_fallback(image: Image.Image) -> Image.Image:
    """Apply a heuristic tint to ``image`` when no trained model is available.

    Steps:

    1. Convert to RGB, then to OpenCV's 8-bit LAB representation.
    2. Boost lightness contrast with CLAHE (clip limit 3.0, 8x8 tiles).
    3. Raise the ``a`` and ``b`` chroma channels by an amount that grows with
       the original lightness (``a``: +2..+8, ``b``: +3..+12).
    4. Convert back to RGB and scale HSV saturation by 1.2.

    This is a fixed heuristic, not a learned colorization.

    Args:
        image: Source image in any PIL mode.

    Returns:
        A new RGB ``PIL.Image``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    lab = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2LAB)
    lightness, a, b = cv2.split(lab)

    # Contrast-enhanced lightness is what ends up in the output image.
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    l_enhanced = clahe.apply(lightness)

    # The tint strength is driven by the ORIGINAL lightness, scaled to 0..1:
    # the mask is 0 up to 0.3, rises linearly, and saturates at 1 from 0.8.
    l_normalized = lightness.astype(np.float32) / 255.0
    brightness_mask = np.clip((l_normalized - 0.3) * 2, 0, 1)

    # Blend a small offset (dark pixels) with a larger one (bright pixels).
    a_hint = np.clip(a.astype(np.float32) + brightness_mask * 8 + (1 - brightness_mask) * 2, 0, 255).astype(np.uint8)
    b_hint = np.clip(b.astype(np.float32) + brightness_mask * 12 + (1 - brightness_mask) * 3, 0, 255).astype(np.uint8)

    lab_colored = cv2.merge([l_enhanced, a_hint, b_hint])
    colored_rgb = cv2.cvtColor(lab_colored, cv2.COLOR_LAB2RGB)

    # Mild saturation boost, clipped to the uint8 range.
    hsv = cv2.cvtColor(colored_rgb, cv2.COLOR_RGB2HSV)
    hsv[:, :, 1] = np.clip(hsv[:, :, 1].astype(np.float32) * 1.2, 0, 255).astype(np.uint8)
    colored_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    return Image.fromarray(colored_rgb)
|
|
|
|
def _tensor_to_rgb_image(tensor: torch.Tensor) -> Image.Image:
    """Convert a ``(C, H, W)`` or ``(1, C, H, W)`` tensor into an RGB PIL image.

    Floating-point tensors of any precision are clamped to ``[0, 1]`` and
    scaled to 0..255; integer tensors are cast to ``uint8`` unchanged.

    Raises:
        ValueError: If the tensor is not 3-D after dropping a batch axis.
    """
    # detach(): ``.numpy()`` refuses tensors that are part of an autograd graph.
    tensor = tensor.detach().cpu()
    if tensor.dim() == 4:
        tensor = tensor[0]
    if tensor.dim() != 3:
        raise ValueError(f"Unexpected tensor shape: {tensor.shape}")

    tensor = tensor.permute(1, 2, 0)  # channels-first -> channels-last
    if tensor.is_floating_point():
        # Work in float32 so half / bfloat16 / double inputs behave alike.
        tensor = (torch.clamp(tensor.float(), 0, 1) * 255).byte()
    else:
        tensor = tensor.to(torch.uint8)
    return Image.fromarray(tensor.contiguous().numpy(), 'RGB')


def colorize_pil(image: Image.Image) -> Image.Image:
    """Colorize ``image`` with the best backend that is currently loaded.

    Order of preference: the FastAI learner (``learn``), then the PyTorch
    colorizer (``pytorch_colorizer``), then the heuristic LAB fallback.

    Args:
        image: Input PIL image.

    Returns:
        The colorized image. On the FastAI path it is always in RGB mode; the
        other paths return whatever the respective backend produces.

    Raises:
        ValueError: If the FastAI prediction is neither a PIL image nor a
            tensor of a supported shape.
    """
    if learn is not None:
        if image.mode != "RGB":
            image = image.convert("RGB")
        pred = learn.predict(image)

        # predict() may return a sequence; its first element is taken as the
        # image. An empty sequence falls back to the input itself.
        if isinstance(pred, (list, tuple)):
            colorized = pred[0] if len(pred) > 0 else image
        else:
            colorized = pred

        if not isinstance(colorized, Image.Image):
            if not isinstance(colorized, torch.Tensor):
                raise ValueError(f"Unexpected prediction type: {type(colorized)}")
            colorized = _tensor_to_rgb_image(colorized)

        if colorized.mode != "RGB":
            colorized = colorized.convert("RGB")
        return colorized

    if pytorch_colorizer is not None:
        return pytorch_colorizer.colorize(image)

    logger.info("No model loaded, using enhanced colorization fallback (LAB color space method)")
    return simple_colorize_fallback(image)
|
|
@app.post("/colorize")
async def colorize_api(
    request: Request,
    file: UploadFile = File(...),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
    verified: bool = Depends(verify_request),
):
    """Colorize an uploaded image and return the result as a PNG download.

    Requires Firebase authentication unless ``DISABLE_AUTH=true`` (enforced by
    the ``verify_request`` dependency).

    Args:
        request: Incoming request; supplies the client IP.
        file: Uploaded image; its content type must start with ``image/``.
        user_id: Optional caller-supplied user id, used for logging only.
        category_id: Accepted as a form field; not used by this handler.
        categoryId: Camel-case alias of ``category_id``; likewise not used.
        verified: Outcome of ``verify_request`` (not used in the body).

    Returns:
        ``FileResponse`` with the colorized PNG, named ``colorized_<uuid>.png``.

    Raises:
        HTTPException: 400 when the upload is not declared as an image;
            500 when decoding, colorizing, saving or logging fails.
    """
    import time

    # Monotonic clock: elapsed time must not be affected by wall-clock changes.
    started = time.perf_counter()

    effective_user_id = _resolve_user_id(request, user_id)
    ip_address = request.client.host if request.client else None

    if not file.content_type or not file.content_type.startswith("image/"):
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error="File must be an image",
            user_id=effective_user_id,
            ip_address=ip_address,
        )
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        img_bytes = await file.read()
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        logger.info("Colorizing image...")
        colorized = colorize_pil(image)

        # Covers reading, decoding and inference; saving is not included.
        processing_time = time.perf_counter() - started

        result_id = str(uuid.uuid4())
        output_filename = f"{result_id}.png"
        output_path = RESULT_DIR / output_filename
        colorized.save(output_path, "PNG")

        logger.info("Colorized image saved: %s", output_filename)

        log_colorization(
            result_id=result_id,
            model_type=model_type,
            processing_time=processing_time,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="success",
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type},
            response_data={"result_id": result_id, "filename": output_filename},
            user_id=effective_user_id,
            ip_address=ip_address,
        )

        return FileResponse(
            output_path,
            media_type="image/png",
            filename=f"colorized_{output_filename}",
        )
    except Exception as e:
        error_msg = str(e)
        # exception() keeps the traceback in the log.
        logger.exception("Error colorizing image: %s", error_msg)

        log_colorization(
            result_id=None,
            model_type=model_type,
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg,
        )

        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Error colorizing image: {error_msg}",
        ) from e
|
|
| |
| |
| |
def gradio_colorize(image):
    """Gradio callback: return the colorized image, or ``None``.

    ``None`` is returned both when no image was supplied and when
    colorization raises; the error is logged rather than propagated.
    """
    if image is None:
        return None

    try:
        result = colorize_pil(image)
    except Exception as exc:
        logger.error("Gradio colorization error: %s", str(exc))
        return None
    return result
|
|
# Texts shown at the top of the Gradio page.
title = "🎨 Image Colorizer"
description = "Upload a black & white photo to generate a colorized version. Uses AI model when available, otherwise uses enhanced colorization fallback."

# One image in, one image out; both are exchanged as PIL images.
iface = gr.Interface(
    fn=gradio_colorize,
    title=title,
    description=description,
    inputs=gr.Image(type="pil", label="Upload B&W Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
)

# Serve the Gradio UI from the root path of the FastAPI application.
app = gr.mount_gradio_app(app, iface, path="/")
|
|
| |
| |
| |
# Script entry point: listen on all interfaces, port taken from $PORT
# (7860 when the variable is unset).
if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
    )
|
|
|
|