# Source: text-guided-image-colorization / app / changes_media_clicks.py
# Committed by: LogicGoInfotechSpaces
# Commit: 206be02 "Update app/changes_media_clicks.py" (verified)
#===========================================================
#main.py
#===========================================================
from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Request, Form, Depends, Body
from typing import Optional, Tuple
from fastapi.responses import FileResponse
from huggingface_hub import hf_hub_download
import uuid
import os
import io
import json
import logging
import httpx
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
import numpy as np
from pydantic import BaseModel, EmailStr
from app.database import (
get_database,
log_api_call,
log_image_upload,
log_colorization,
log_media_click,
close_connection,
get_category_id_from_collage_maker,
get_category_id_from_ai_enhancer,
)
# Optional integrations: every name falls back to None (or CCO_AVAILABLE=False)
# when its package is missing, so the API can still start without it.
try:
    import firebase_admin
    from firebase_admin import auth as firebase_auth, app_check, credentials
except ImportError:
    firebase_admin = firebase_auth = app_check = credentials = None

try:
    import boto3
except ImportError:
    boto3 = None

# CCO colorizers (eccv16 / siggraph17) shipped with the app package.
try:
    from app.colorizers import eccv16, siggraph17
    from app.colorizers.util import preprocess_img, postprocess_tens
except ImportError as e:
    print(f"⚠️ CCO colorizers not available: {e}")
    CCO_AVAILABLE = False
else:
    CCO_AVAILABLE = True

logger = logging.getLogger(__name__)
# -------------------------------------------------
# 🚀 FastAPI App
# -------------------------------------------------
app = FastAPI(title="Text-Guided Image Colorization API")

# CORS: every origin, method and header is accepted and credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider an explicit origin list for production deployments.
from fastapi.middleware.cors import CORSMiddleware

_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# -------------------------------------------------
# 🔐 Firebase Initialization (ENV-based)
# -------------------------------------------------
# Initialise the default Firebase app from the JSON service-account document
# stored in the FIREBASE_CREDENTIALS environment variable. Every failure is
# reported and swallowed so the API still starts with Firebase disabled.
if not firebase_admin:
    print("⚠️ Firebase Admin SDK not available. Firebase features disabled.")
else:
    try:
        firebase_json = os.getenv("FIREBASE_CREDENTIALS")
        if not firebase_json:
            print("⚠️ No Firebase credentials found. Firebase disabled.")
        else:
            print("🔥 Loading Firebase credentials from ENV...")
            firebase_dict = json.loads(firebase_json)
            cred = credentials.Certificate(firebase_dict)
            firebase_admin.initialize_app(cred)
    except Exception as e:
        print("❌ Firebase initialization failed:", e)
# -------------------------------------------------
# 📁 Directories (FIXED FOR HUGGINGFACE SPACES)
# -------------------------------------------------
# Writable scratch locations (/tmp is writable on Hugging Face Spaces).
UPLOAD_DIR = "/tmp/uploads"
RESULTS_DIR = "/tmp/results"
COMPRESSED_DIR = "/tmp/compressed"
for _required_dir in (UPLOAD_DIR, RESULTS_DIR, COMPRESSED_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

# Category id used for media-click logging when none is supplied; overridable via env.
MEDIA_CLICK_DEFAULT_CATEGORY = os.getenv("DEFAULT_CATEGORY_FALLBACK", "69368fcd2e46bd68ae1889b2")
# -------------------------------------------------
# ☁️ DigitalOcean Spaces Configuration
# -------------------------------------------------
# Connection settings for DigitalOcean Spaces; each is None when its env var is unset.
(
    DO_SPACES_KEY,
    DO_SPACES_SECRET,
    DO_SPACES_REGION,
    DO_SPACES_ENDPOINT,
    DO_SPACES_BUCKET,
) = (
    os.getenv(env_name)
    for env_name in (
        "DO_SPACES_KEY",
        "DO_SPACES_SECRET",
        "DO_SPACES_REGION",
        "DO_SPACES_ENDPOINT",
        "DO_SPACES_BUCKET",
    )
)
# Key prefix for uploaded objects, without leading/trailing slashes.
DO_SPACES_BASE_FOLDER = os.getenv("DO_SPACES_BASE_FOLDER", "valentine").strip("/")
# Lazily created boto3 client (see _get_spaces_client).
_spaces_client = None
def _spaces_enabled() -> bool:
    """Return True when boto3 is importable and every DO_SPACES_* setting is non-empty."""
    if boto3 is None:
        return False
    required_settings = (
        DO_SPACES_KEY,
        DO_SPACES_SECRET,
        DO_SPACES_REGION,
        DO_SPACES_ENDPOINT,
        DO_SPACES_BUCKET,
    )
    return all(required_settings)
def _get_spaces_client():
    """
    Return the module-wide boto3 S3 client for DigitalOcean Spaces, creating it on first use.

    The client is cached in the ``_spaces_client`` global; creation is not
    guarded by a lock.
    """
    global _spaces_client
    if _spaces_client is not None:
        return _spaces_client
    _spaces_client = boto3.client(
        "s3",
        region_name=DO_SPACES_REGION,
        endpoint_url=DO_SPACES_ENDPOINT,
        aws_access_key_id=DO_SPACES_KEY,
        aws_secret_access_key=DO_SPACES_SECRET,
    )
    return _spaces_client
def _build_spaces_public_url(object_key: str) -> str:
    """
    Build the virtual-hosted public URL ``https://<bucket>.<endpoint-host>/<object_key>``.

    DO_SPACES_ENDPOINT may be given with or without a scheme; only its host part is used.
    """
    endpoint = DO_SPACES_ENDPOINT
    if "://" not in endpoint:
        endpoint = f"https://{endpoint}"
    parsed_endpoint = urlparse(endpoint)
    host = (parsed_endpoint.netloc or parsed_endpoint.path).strip("/")
    return f"https://{DO_SPACES_BUCKET}.{host}/{object_key}"
def upload_bytes_to_spaces(file_bytes: bytes, object_key: str, content_type: str) -> str:
    """
    Upload an in-memory payload to the Spaces bucket as a public-read object.

    Returns:
        The public URL of the stored object.

    Raises:
        RuntimeError: If Spaces is not configured (see _spaces_enabled).
    """
    if not _spaces_enabled():
        raise RuntimeError("DigitalOcean Spaces is not configured")
    _get_spaces_client().put_object(
        Bucket=DO_SPACES_BUCKET,
        Key=object_key,
        Body=file_bytes,
        ACL="public-read",
        ContentType=content_type,
    )
    return _build_spaces_public_url(object_key)
def upload_file_to_spaces(file_path: str, object_key: str, content_type: str) -> str:
    """
    Upload a local file to the Spaces bucket as a public-read object.

    The file is opened in binary mode and streamed as the request body.

    Returns:
        The public URL of the stored object.

    Raises:
        RuntimeError: If Spaces is not configured (see _spaces_enabled).
    """
    if not _spaces_enabled():
        raise RuntimeError("DigitalOcean Spaces is not configured")
    spaces = _get_spaces_client()
    with open(file_path, "rb") as source_file:
        spaces.put_object(
            Bucket=DO_SPACES_BUCKET,
            Key=object_key,
            Body=source_file,
            ACL="public-read",
            ContentType=content_type,
        )
    return _build_spaces_public_url(object_key)
# -------------------------------------------------
# 🧠 Load GAN Colorization Model
# -------------------------------------------------
MODEL_REPO = "Hammad712/GAN-Colorization-Model"
MODEL_FILENAME = "generator.pt"

# Fetch the generator checkpoint from the Hugging Face Hub at import time.
print("⬇️ Downloading GAN model...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)

print("📦 Loading GAN model weights...")
# NOTE(review): torch.load unpickles the checkpoint. On torch versions where
# weights_only defaults to False this can execute code embedded in the
# downloaded file; consider weights_only=True once the checkpoint is confirmed
# to contain only tensors.
state_dict = torch.load(model_path, map_location="cpu")

# TODO: wire up the real generator architecture, e.g.
#   from model import ColorizeNet
#   model = ColorizeNet(); model.load_state_dict(state_dict); model.eval()
# colorize_image_gan() below does not use these weights yet.
# -------------------------------------------------
# 🧠 Load CCO Colorization Models
# -------------------------------------------------
# Registry of loaded CCO networks, keyed by model name ("eccv16", "siggraph17").
cco_models = {}
if CCO_AVAILABLE:
    print("📦 Loading CCO models...")
    try:
        # Pretrained weights, switched to inference mode.
        cco_models["eccv16"] = eccv16(pretrained=True).eval()
        cco_models["siggraph17"] = siggraph17(pretrained=True).eval()
    except Exception as e:
        print(f"⚠️ Failed to load CCO models: {e}")
        CCO_AVAILABLE = False
    else:
        print("✅ CCO models loaded successfully!")
def colorize_image_gan(img: Image.Image):
    """
    Placeholder "GAN" colorizer: return a grayscale copy of *img* as a 3-channel image.

    The real generator is not wired up yet, so this adds no colour. It
    converts to luminance, replicates that channel three times and converts
    back to a PIL image with the same width and height as *img*.
    """
    to_tensor = transforms.ToTensor()
    gray = to_tensor(img.convert("L")).unsqueeze(0)  # (1, 1, H, W)
    rgb = gray.repeat(1, 3, 1, 1)  # (1, 3, H, W)
    # squeeze(0) removes only the batch axis. A bare squeeze() would also drop
    # H or W for 1-pixel-tall/wide images and yield a wrongly shaped image
    # (or fail outright for a 1x1 image).
    return transforms.ToPILImage()(rgb.squeeze(0))
def colorize_image_cco(img: Image.Image, model_name: str = "eccv16"):
    """
    Colorize *img* with one of the CCO networks ("eccv16" or "siggraph17").

    Unknown *model_name* values silently fall back to "eccv16".

    Returns:
        A new RGB PIL image.

    Raises:
        ValueError: If CCO support is unavailable or the chosen network was not loaded.
    """
    if not CCO_AVAILABLE:
        raise ValueError("CCO models are not available")

    chosen_name = model_name if model_name in ("eccv16", "siggraph17") else "eccv16"
    network = cco_models.get(chosen_name)
    if network is None:
        raise ValueError(f"CCO model '{chosen_name}' not loaded")

    rgb_img = img if img.mode == "RGB" else img.convert("RGB")
    pixels = np.asarray(rgb_img)
    if pixels.ndim == 2:
        # Defensive: replicate a single channel to three.
        pixels = np.tile(pixels[:, :, None], 3)

    tens_l_orig, tens_l_rs = preprocess_img(pixels)
    with torch.no_grad():
        predicted_ab = network(tens_l_rs)

    # postprocess_tens is expected to return RGB floats in [0, 1]; clamp before
    # scaling to 8-bit.
    colorized = np.clip(postprocess_tens(tens_l_orig, predicted_ab), 0, 1)
    return Image.fromarray((colorized * 255).astype(np.uint8), 'RGB')
def colorize_image(img: Image.Image, model_type: str = "gan", cco_model: str = "eccv16"):
    """
    Dispatch *img* to the requested colorizer.

    Args:
        img: PIL image to colorize.
        model_type: "cco" selects colorize_image_cco; any other value uses the
            GAN placeholder (colorize_image_gan).
        cco_model: "eccv16" or "siggraph17"; only consulted for "cco".

    Returns:
        The colorized PIL image.
    """
    if model_type != "cco":
        return colorize_image_gan(img)
    return colorize_image_cco(img, cco_model)
def _discard_temp_file(path: Optional[str]) -> None:
    """Delete *path* if it is set, ignoring OS errors (e.g. the file is already gone)."""
    if not path:
        return
    try:
        os.unlink(path)
    except OSError:
        # Best-effort cleanup of a scratch file; must not mask the real outcome.
        pass


def compress_image_to_target_size(img: Image.Image, target_size_mb: float = 2.5, min_size_mb: float = 2.0, max_size_mb: float = 3.0) -> Tuple[Image.Image, str]:
    """
    Re-encode *img* as JPEG, tuning quality and scale to land in a size window.

    The search starts at JPEG quality 85 at full resolution. For at most 20
    attempts it lowers quality (then resolution) while the file is larger than
    ``max_size_mb``, and raises quality (or restores resolution) while it is
    smaller than ``min_size_mb``. It stops early when the size falls inside
    the window or when no further adjustment is possible.

    Args:
        img: PIL image to compress; converted to RGB first if needed.
        target_size_mb: Accepted for API compatibility; the search is driven
            only by ``min_size_mb`` / ``max_size_mb`` and never reads it.
        min_size_mb: Lower bound of the acceptable size, in MiB.
        max_size_mb: Upper bound of the acceptable size, in MiB.

    Returns:
        ``(image, path)``. *image* is the (possibly resized) RGB image that
        was encoded. *path* is a temporary ``.jpg`` file that still exists on
        return and holds exactly that encoding; the caller owns it and should
        move or delete it.
    """
    import tempfile

    if img.mode != "RGB":
        img = img.convert("RGB")

    original_width, original_height = img.size
    source_img = img.copy()
    quality = 85
    scale_factor = 1.0
    max_attempts = 20
    resized_img = source_img
    temp_path = None

    for attempt in range(max_attempts):
        # Keep only the file of the attempt that is eventually returned:
        # drop the previous attempt's output before writing a new one.
        _discard_temp_file(temp_path)
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
            temp_path = tmp_file.name

        try:
            if scale_factor < 1.0:
                new_size = (
                    max(1, int(original_width * scale_factor)),
                    max(1, int(original_height * scale_factor)),
                )
                resized_img = source_img.resize(new_size, Image.Resampling.LANCZOS)
            else:
                resized_img = source_img
            resized_img.save(temp_path, "JPEG", quality=quality, optimize=True)
            file_size_mb = os.path.getsize(temp_path) / (1024 * 1024)
        except Exception:
            # Don't leak the scratch file when encoding fails.
            _discard_temp_file(temp_path)
            raise

        logger.info("Compression attempt %d: quality=%d, scale=%.2f, size=%.2f MB",
                    attempt + 1, quality, scale_factor, file_size_mb)

        if min_size_mb <= file_size_mb <= max_size_mb:
            logger.info("Target size achieved: %.2f MB", file_size_mb)
            return resized_img, temp_path

        if file_size_mb > max_size_mb:
            if quality > 30:
                quality -= 5
            elif scale_factor > 0.5:
                # Quality floor reached: shrink dimensions and reset quality.
                scale_factor -= 0.05
                quality = 75
            else:
                logger.warning("Reached minimum compression, file size: %.2f MB", file_size_mb)
                return resized_img, temp_path
        else:
            # Too small: spend more bytes, aggressively when far below target.
            if file_size_mb < 1.0 and quality < 90:
                quality += 5
            elif file_size_mb < 1.0 and scale_factor < 1.0:
                scale_factor = min(1.0, scale_factor + 0.1)
            elif file_size_mb >= 1.0 and quality < 95:
                quality = min(95, quality + 2)
            else:
                # Nothing left to adjust; re-encoding with the same settings
                # would only repeat this result.
                logger.info("File size %.2f MB is below minimum but at max quality, accepting", file_size_mb)
                return resized_img, temp_path

    # Attempts exhausted: return the last encoding (its file was kept).
    return resized_img, temp_path
# -------------------------------------------------
# 🗄️ MongoDB Initialization
# -------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Open the MongoDB connection at startup; failures are reported but not fatal."""
    try:
        connected = get_database() is not None
    except Exception as e:
        print(f"⚠️ MongoDB initialization failed: {e}")
        return
    if connected:
        print("✅ MongoDB initialized successfully!")
@app.on_event("shutdown")
async def shutdown_event():
    """Close the MongoDB connection when the application stops."""
    close_connection()
    print("Application shutdown")
# -------------------------------------------------
# 🩺 Health Check
# -------------------------------------------------
@app.get("/health")
def health_check(request: Request):
    """
    Health probe: return a static status payload and record the call via log_api_call.

    NOTE(review): model_loaded/model_type/provider are hard-coded and do not
    reflect the GAN/CCO models this module actually loads.
    """
    client_ip = request.client.host if request.client else None
    payload = {
        "status": "healthy",
        "model_loaded": True,
        "model_type": "hf_inference_api",
        "provider": "fal-ai",
    }
    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=payload,
        ip_address=client_ip,
    )
    return payload
# -------------------------------------------------
# 🔐 Auth Models
# -------------------------------------------------
class RegisterRequest(BaseModel):
    """Request body for POST /auth/register."""

    email: EmailStr
    password: str
    # Optional name passed to Firebase as the user's display name.
    display_name: Optional[str] = None
class LoginRequest(BaseModel):
    """Request body for POST /auth/login (email/password credentials)."""

    email: EmailStr
    password: str
class TokenResponse(BaseModel):
    """Response body returned by the register and login endpoints."""

    # Firebase ID token (password login) or custom token (register / fallback).
    id_token: str
    # Only filled in by the REST-API password login flow.
    refresh_token: Optional[str] = None
    # Token lifetime in seconds.
    expires_in: int
    token_type: str = "Bearer"
    # Basic profile: uid, email, display_name, email_verified.
    user: dict
# -------------------------------------------------
# 🔐 Firebase Token Validator & Auth Helpers
# -------------------------------------------------
def _extract_bearer_token(authorization_header: Optional[str]) -> Optional[str]:
    """
    Return the token from an ``Authorization: Bearer <token>`` header value.

    The scheme match is case-insensitive and the token is stripped. Returns
    None when the header is empty, has no space, or uses another scheme.
    """
    if not authorization_header:
        return None
    scheme, separator, token_part = authorization_header.partition(" ")
    if separator and scheme.lower() == "bearer":
        return token_part.strip()
    return None
async def verify_request(request: Request):
    """
    Best-effort Firebase verification, used as a FastAPI dependency.

    Tried in order:
      1. a Firebase App Check token in the ``X-Firebase-AppCheck`` header;
      2. a Firebase Auth ID token in ``Authorization: Bearer <token>``. On
         success the decoded claims are stored on ``request.state.user``.

    Verification failures are only logged. The function returns True in every
    case, including when Firebase is unavailable or uninitialised, so it never
    blocks a request. Handlers that need a user must check
    ``request.state.user`` themselves.
    """
    firebase_ready = bool(firebase_admin) and bool(getattr(firebase_admin, "_apps", None))
    if not firebase_ready:
        return True

    headers = request.headers

    app_check_token = headers.get("X-Firebase-AppCheck") or headers.get("x-firebase-appcheck")
    if app_check_token and app_check:
        try:
            app_check_claims = app_check.verify_token(app_check_token)
        except Exception as e:
            logger.warning("App Check token verification failed: %s", str(e))
        else:
            logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
            return True

    bearer = _extract_bearer_token(headers.get("Authorization"))
    if bearer and firebase_auth:
        try:
            decoded = firebase_auth.verify_id_token(bearer)
        except Exception as e:
            logger.warning("Auth token verification failed: %s", str(e))
        else:
            request.state.user = decoded
            logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
            return True

    # No valid token: still allowed (public endpoints).
    return True
# app_check, HTTPException and os come from the guarded imports at the top of
# the module. An unguarded `from firebase_admin import app_check` here would
# crash startup whenever firebase_admin is not installed.
def verify_app_check_token(
    token: Optional[str],
    *,
    required: bool = False,
):
    """
    Verify a Firebase App Check token taken from the ``X-Firebase-AppCheck`` header.

    With ``required=False`` (the default) the check is advisory. A missing
    token, a missing App Check SDK, or an invalid token all let the request
    through; the last two are logged as warnings. With ``required=True``, each
    of those cases is rejected.

    Args:
        token: Raw App Check token, or None/"" when the header was absent.
        required: Keyword-only; enforce the token instead of only logging problems.

    Returns:
        True whenever the request is allowed.

    Raises:
        HTTPException: Only when ``required=True``: 401 if the token is missing
            or invalid, 503 if App Check is unavailable.
    """
    if not token:
        if required:
            raise HTTPException(
                status_code=401,
                detail="Firebase App Check token missing"
            )
        return True  # OPTIONAL -> allow request

    if not app_check:
        if required:
            raise HTTPException(
                status_code=503,
                detail="Firebase App Check not configured"
            )
        logger.warning("Firebase App Check not available, but token was provided. Allowing request.")
        return True

    try:
        app_check.verify_token(token)
    except Exception as e:
        if not required:
            logger.warning("Invalid Firebase App Check token provided, but allowing request (optional): %s", str(e))
            return True
        logger.error("Invalid Firebase App Check token (required): %s", str(e))
        raise HTTPException(
            status_code=401,
            detail="Invalid Firebase App Check token"
        ) from e

    logger.debug("Firebase App Check token verified successfully")
    return True
def _resolve_user_id(request: Request, supplied_user_id: Optional[str]) -> Optional[str]:
    """
    Normalise a client-supplied user id.

    Returns the stripped id, or None when it is missing or blank (callers such
    as log_media_click then generate one). *request* is accepted for interface
    compatibility but not read.
    """
    cleaned = (supplied_user_id or "").strip()
    return cleaned or None
# -------------------------------------------------
# 🔐 Auth Endpoints
# -------------------------------------------------
@app.post("/auth/register", response_model=TokenResponse)
async def register_user(user_data: RegisterRequest):
    """
    Register a new user with email and password via the Firebase Admin SDK.

    The ``id_token`` field of the response carries a Firebase *custom* token,
    which the client can exchange for an ID token.

    Raises:
        HTTPException: 503 if Firebase is unavailable; 400 if the email is
            already registered or the SDK rejects the supplied fields; 500 for
            any other failure.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")
    if not hasattr(firebase_admin, '_apps') or not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")
    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")

    try:
        user_record = firebase_auth.create_user(
            email=user_data.email,
            password=user_data.password,
            display_name=user_data.display_name,
            email_verified=False,
        )
    except Exception as e:
        error_msg = str(e)
        # Prefer the SDK's dedicated error type; the substring check is a
        # fallback. Matching any message that merely mentions "email" would
        # misreport unrelated failures (e.g. a malformed address) as duplicates.
        email_exists_error = getattr(firebase_auth, "EmailAlreadyExistsError", None)
        if (email_exists_error is not None and isinstance(e, email_exists_error)) or "already exists" in error_msg.lower():
            raise HTTPException(status_code=400, detail="Email already registered") from e
        if isinstance(e, ValueError):
            # The Admin SDK raises ValueError for invalid user properties
            # (e.g. malformed email, too-short password): a client error.
            raise HTTPException(status_code=400, detail=f"Invalid registration data: {error_msg}") from e
        logger.error("Registration error: %s", error_msg)
        raise HTTPException(status_code=500, detail=f"Registration failed: {error_msg}") from e

    try:
        # Custom token the client can exchange for an ID token.
        custom_token = firebase_auth.create_custom_token(user_record.uid)
    except Exception as e:
        # NOTE: the Firebase user already exists at this point.
        error_msg = str(e)
        logger.error("Registration error: %s", error_msg)
        raise HTTPException(status_code=500, detail=f"Registration failed: {error_msg}") from e

    logger.info("User registered: %s (uid: %s)", user_data.email, user_record.uid)
    return TokenResponse(
        id_token=custom_token.decode('utf-8') if isinstance(custom_token, bytes) else custom_token,
        token_type="Bearer",
        expires_in=3600,
        user={
            "uid": user_record.uid,
            "email": user_record.email,
            "display_name": user_record.display_name or "",
            "email_verified": user_record.email_verified,
        },
    )
@app.post("/auth/login", response_model=TokenResponse)
async def login_user(credentials: LoginRequest):
    """
    Log in with email and password via the Firebase Auth REST API.

    The password is checked by Firebase's ``accounts:signInWithPassword``
    endpoint. That endpoint needs the project's Web API key, taken from the
    ``FIREBASE_API_KEY`` env var or ``settings.FIREBASE_API_KEY``. Without the
    key the password cannot be verified, so the request is refused (503)
    rather than issuing a token to anyone who knows the email address.

    Raises:
        HTTPException:
            - 503 if Firebase, Firebase Auth or the API key is unavailable.
            - 401 if Firebase rejects the credentials.
            - 500 on transport or unexpected errors.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")
    if not hasattr(firebase_admin, '_apps') or not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")
    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")

    from app.config import settings
    firebase_api_key = os.getenv("FIREBASE_API_KEY") or getattr(settings, 'FIREBASE_API_KEY', '')
    if not firebase_api_key:
        # Security: never mint a token without verifying the password.
        logger.error("Login rejected: FIREBASE_API_KEY is not configured, password cannot be verified")
        raise HTTPException(status_code=503, detail="Firebase API key not configured")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={firebase_api_key}",
                json={
                    "email": credentials.email,
                    "password": credentials.password,
                    "returnSecureToken": True,
                },
            )
            if response.status_code != 200:
                error_data = response.json()
                error_msg = error_data.get("error", {}).get("message", "Authentication failed")
                raise HTTPException(status_code=401, detail=error_msg)

            data = response.json()
            logger.info("User login successful: %s", credentials.email)

            # Profile details come from the Admin SDK.
            user_record = firebase_auth.get_user(data["localId"])
            return TokenResponse(
                id_token=data["idToken"],
                refresh_token=data.get("refreshToken"),
                expires_in=int(data.get("expiresIn", 3600)),
                token_type="Bearer",
                user={
                    "uid": user_record.uid,
                    "email": user_record.email,
                    "display_name": user_record.display_name or "",
                    "email_verified": user_record.email_verified,
                },
            )
    except httpx.HTTPError as e:
        logger.error("HTTP error during login: %s", str(e))
        raise HTTPException(status_code=500, detail="Authentication service unavailable")
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Login error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Login failed: {str(e)}")
@app.get("/auth/me")
async def get_current_user(request: Request, verified: bool = Depends(verify_request)):
    """
    Return the profile of the caller identified by a verified Firebase ID token.

    Relies on verify_request having stored the decoded token on
    ``request.state.user``. Responds 401 when no such user is present, 404 when
    the uid cannot be looked up, and 503 when Firebase is unavailable.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")
    if not getattr(firebase_admin, "_apps", None):
        raise HTTPException(status_code=503, detail="Firebase not initialized")
    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")

    if not hasattr(request.state, "user"):
        raise HTTPException(status_code=401, detail="Not authenticated")

    uid = request.state.user.get("uid")
    try:
        record = firebase_auth.get_user(uid)
    except Exception as e:
        logger.error("Error getting user: %s", str(e))
        raise HTTPException(status_code=404, detail="User not found")

    return {
        "uid": record.uid,
        "email": record.email,
        "display_name": record.display_name or "",
        "email_verified": record.email_verified,
    }
@app.post("/auth/refresh")
async def refresh_token(refresh_token_param: str = Body(..., embed=True)):
    """
    Exchange a Firebase refresh token for a new ID token.

    Calls the Secure Token API with ``grant_type`` / ``refresh_token`` sent as
    an ``application/x-www-form-urlencoded`` body, the format documented for
    that endpoint.

    Raises:
        HTTPException:
            - 503 if no API key is configured.
            - 401 if Firebase rejects the refresh token.
            - 500 on transport or unexpected errors.
    """
    from app.config import settings
    firebase_api_key = os.getenv("FIREBASE_API_KEY") or getattr(settings, 'FIREBASE_API_KEY', '')
    if not firebase_api_key:
        raise HTTPException(status_code=503, detail="Firebase API key not configured")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://securetoken.googleapis.com/v1/token?key={firebase_api_key}",
                # httpx form-encodes a dict passed as `data`.
                data={
                    "grant_type": "refresh_token",
                    "refresh_token": refresh_token_param,
                },
            )
            if response.status_code != 200:
                error_data = response.json()
                error_msg = error_data.get("error", {}).get("message", "Token refresh failed")
                raise HTTPException(status_code=401, detail=error_msg)

            data = response.json()
            return {
                "id_token": data["id_token"],
                "refresh_token": data.get("refresh_token"),
                "expires_in": int(data.get("expires_in", 3600)),
                "token_type": "Bearer",
            }
    except httpx.HTTPError as e:
        logger.error("HTTP error during token refresh: %s", str(e))
        raise HTTPException(status_code=500, detail="Token refresh service unavailable")
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Token refresh error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Token refresh failed: {str(e)}")
# -------------------------------------------------
# 📤 Upload Image
# -------------------------------------------------
@app.post("/upload")
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
):
    """
    Store an uploaded image under UPLOAD_DIR and return its id and public URL.

    The App Check header is verified in optional mode. ``category_id`` and
    ``categoryId`` are still accepted so existing clients keep working, but
    this endpoint does not use them.

    Raises:
        HTTPException: 400 if the upload is not declared as ``image/*``.
    """
    verify_app_check_token(x_firebase_appcheck)
    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)

    # UploadFile.content_type is None when the part has no Content-Type header;
    # treat that as "not an image" (400) instead of crashing with a 500.
    if not (file.content_type or "").startswith("image/"):
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=400,
            error="Invalid file type",
            ip_address=ip_address,
        )
        raise HTTPException(status_code=400, detail="Invalid file type")

    image_stem = str(uuid.uuid4())
    # NOTE(review): bytes are stored verbatim under a .jpg name, whatever the real format.
    image_id = f"{image_stem}.jpg"
    file_path = os.path.join(UPLOAD_DIR, image_id)

    img_bytes = await file.read()
    file_size = len(img_bytes)
    with open(file_path, "wb") as f:
        f.write(img_bytes)

    base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
    response_data = {
        "success": True,
        "image_id": image_stem,
        "file_url": f"{base_url}/uploads/{image_id}",
    }

    log_image_upload(
        image_id=image_stem,
        filename=file.filename or image_id,
        file_size=file_size,
        content_type=file.content_type or "image/jpeg",
        user_id=effective_user_id,
        ip_address=ip_address,
    )
    log_api_call(
        endpoint="/upload",
        method="POST",
        status_code=200,
        request_data={"filename": file.filename, "content_type": file.content_type},
        response_data=response_data,
        user_id=effective_user_id,
        ip_address=ip_address,
    )
    return response_data
# -------------------------------------------------
# 🎨 Colorize Image
# -------------------------------------------------
@app.post("/colorize")
async def colorize(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
    model: Optional[str] = Form(None),  # "gan", "cco", "cco-eccv16", "cco-siggraph17" (default: CCO if available)
    appname: Optional[str] = Form(None),  # Optional app name (e.g., "collage-maker", "ai-enhancer")
):
    """
    Colorize an uploaded image and return URLs for the results.

    Steps:
      1. Verify App Check (optional mode).
      2. Resolve the user, category and app name.
      3. Pick the model, then decode and colorize the image.
      4. Save a full-size PNG under RESULTS_DIR and a size-targeted JPEG under
         COMPRESSED_DIR.
      5. Optionally mirror the source upload and the JPEG to DigitalOcean Spaces.
      6. Log the outcome to MongoDB.

    ``model`` selects "gan" (placeholder), "cco"/"cco-eccv16" or
    "cco-siggraph17". Missing or unrecognised values fall back to CCO eccv16
    when the CCO models loaded, otherwise to "gan".

    Raises:
        HTTPException: 400 if CCO is requested but unavailable, or if the upload
            is not declared as ``image/*``. 500 if decoding, colorizing, saving
            or compressing fails. A failed Spaces upload is only logged, and the
            local compressed URL is returned instead.
    """
    import shutil
    import time

    start_time = time.time()
    verify_app_check_token(x_firebase_appcheck, required=False)
    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)

    # Accept either spelling of the category field; blank means "not provided".
    raw_category_id = category_id or categoryId
    if isinstance(raw_category_id, str):
        raw_category_id = raw_category_id.strip()
    effective_category_id = raw_category_id or None

    # For known apps, fall back to that app's category id from its own database.
    effective_appname = appname.strip() if isinstance(appname, str) and appname.strip() else None
    appname_key = effective_appname.lower() if effective_appname else None
    if appname_key == "collage-maker":
        collage_maker_category_id = get_category_id_from_collage_maker()
        if collage_maker_category_id:
            if not effective_category_id:
                effective_category_id = collage_maker_category_id
                logger.info("Using category ID from collage-maker database: %s", effective_category_id)
        else:
            logger.warning("appname is 'collage-maker' but could not fetch category ID from collage-maker database")
    elif appname_key == "ai-enhancer":
        ai_enhancer_category_id = get_category_id_from_ai_enhancer()
        if ai_enhancer_category_id:
            if not effective_category_id:
                effective_category_id = ai_enhancer_category_id
                logger.info("Using category ID from AI-enhancer database: %s", effective_category_id)
        else:
            logger.warning("appname is 'ai-enhancer' but could not fetch category ID from AI-enhancer database")
    elif effective_appname:
        logger.info("appname provided: %s (not 'collage-maker' or 'ai-enhancer', skipping category lookup)", effective_appname)

    # Model selection. cco_model must always be bound: colorize_image() receives
    # it even on the GAN path (previously it was unbound whenever CCO was
    # unavailable, so every fallback request failed).
    cco_model = "eccv16"
    if CCO_AVAILABLE:
        model_type, model_type_for_log = "cco", "cco-eccv16"
    else:
        model_type, model_type_for_log = "gan", "gan"

    if model:
        model = model.strip().lower()
        if model == "gan":
            # Placeholder GAN (returns a grayscale image).
            model_type = model_type_for_log = "gan"
        elif model == "cco" or model.startswith("cco-"):
            if not CCO_AVAILABLE:
                error_msg = "CCO models are not available"
                log_api_call(
                    endpoint="/colorize",
                    method="POST",
                    status_code=400,
                    error=error_msg,
                    user_id=effective_user_id,
                    ip_address=ip_address,
                )
                log_colorization(
                    result_id=None,
                    model_type="cco",
                    processing_time=(time.time() - start_time),
                    user_id=effective_user_id,
                    ip_address=ip_address,
                    status="fail",
                    error=error_msg,
                    appname=effective_appname,
                )
                raise HTTPException(status_code=400, detail=error_msg)
            model_type = "cco"
            # "cco" and any unknown "cco-*" variant default to eccv16.
            cco_model = "siggraph17" if model == "cco-siggraph17" else "eccv16"
            model_type_for_log = f"cco-{cco_model}"
        # Any other value keeps the defaults chosen above.

    # content_type may be None; treat that as an invalid file type.
    if not (file.content_type or "").startswith("image/"):
        error_msg = "Invalid file type"
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address,
        )
        log_colorization(
            result_id=None,
            model_type=model_type_for_log,
            processing_time=(time.time() - start_time),
            user_id=effective_user_id,
            ip_address=ip_address,
            status="fail",
            error=error_msg,
            appname=effective_appname,
        )
        raise HTTPException(status_code=400, detail=error_msg)

    try:
        input_bytes = await file.read()
        img = Image.open(io.BytesIO(input_bytes))
        if img.mode != "RGB":
            img = img.convert("RGB")

        output_img = colorize_image(img, model_type=model_type, cco_model=cco_model)
        if output_img.mode != "RGB":
            output_img = output_img.convert("RGB")
        processing_time = time.time() - start_time

        result_id_clean = str(uuid.uuid4())
        result_id = f"{result_id_clean}.png"
        output_path = os.path.join(RESULTS_DIR, result_id)
        # Full-size, lossless model output.
        output_img.save(output_path, "PNG")

        logger.info("Creating compressed version targeting 2-3MB file size...")
        compressed_img, temp_compressed_path = compress_image_to_target_size(
            output_img,
            target_size_mb=2.5,
            min_size_mb=2.0,
            max_size_mb=3.0,
        )
        compressed_filename = f"{result_id_clean}_compressed.jpg"
        compressed_path = os.path.join(COMPRESSED_DIR, compressed_filename)
        # Reuse the encoded temp file when present; otherwise re-encode.
        if os.path.exists(temp_compressed_path):
            shutil.move(temp_compressed_path, compressed_path)
        else:
            compressed_img.save(compressed_path, "JPEG", quality=75, optimize=True)

        compressed_size_mb = os.path.getsize(compressed_path) / (1024 * 1024)
        original_size_mb = os.path.getsize(output_path) / (1024 * 1024)
        logger.info("Original image size: %.2f MB, Compressed image size: %.2f MB",
                    original_size_mb, compressed_size_mb)

        base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
        # Same caption for every model.
        caption = "colorize this image with vibrant natural colors, high quality"
        compressed_image_url = f"{base_url}/compressed/{compressed_filename}"

        # Mirror source + compressed result to DigitalOcean Spaces (best effort).
        if _spaces_enabled():
            try:
                source_ext = ".jpg"
                if file.filename and "." in file.filename:
                    source_ext = os.path.splitext(file.filename)[1].lower() or ".jpg"
                source_object_key = f"{DO_SPACES_BASE_FOLDER}/source/{result_id_clean}{source_ext}"
                result_object_key = f"{DO_SPACES_BASE_FOLDER}/results/{compressed_filename}"
                upload_bytes_to_spaces(input_bytes, source_object_key, file.content_type or "image/jpeg")
                compressed_image_url = upload_file_to_spaces(compressed_path, result_object_key, "image/jpeg")
                logger.info(
                    "Uploaded source and compressed result to DO Spaces: %s, %s",
                    source_object_key,
                    result_object_key,
                )
            except Exception as spaces_error:
                logger.warning("Failed to upload to DO Spaces, using local compressed URL: %s", str(spaces_error))
        else:
            logger.info("DO Spaces env not fully configured or boto3 missing; using local compressed URL")

        response_data = {
            "success": True,
            "result_id": result_id_clean,
            "download_url": f"{base_url}/results/{result_id}",
            "api_download_url": f"{base_url}/download/{result_id_clean}",
            "filename": result_id,
            "caption": caption,
            "Compressed_Image_URL": compressed_image_url,
        }

        log_colorization(
            result_id=result_id_clean,
            model_type=model_type_for_log,
            processing_time=processing_time,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="success",
            appname=effective_appname,
        )
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type, "model": model},
            response_data=response_data,
            user_id=effective_user_id,
            ip_address=ip_address,
        )
        return response_data
    except Exception as e:
        error_msg = str(e)
        # logger.exception keeps the traceback for debugging.
        logger.exception("Error colorizing image: %s", error_msg)
        log_colorization(
            result_id=None,
            model_type=model_type_for_log,
            processing_time=(time.time() - start_time),
            user_id=effective_user_id,
            ip_address=ip_address,
            status="fail",
            error=error_msg,
            appname=effective_appname,
        )
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address,
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}")
# -------------------------------------------------
# ⬇️ Download via API (Secure)
# -------------------------------------------------
@app.get("/download/{file_id}")
def download_result(
    request: Request,
    file_id: str,
    x_firebase_appcheck: str = Header(None)
):
    """Serve a stored colorization result by ID after App Check verification.

    ``<file_id>.png`` is preferred; ``<file_id>.jpg`` is accepted for results
    written before PNG output existed. Both hits and misses are recorded via
    ``log_api_call``.

    Raises:
        HTTPException: 404 when neither candidate file exists.
    """
    verify_app_check_token(x_firebase_appcheck)
    client_ip = request.client.host if request.client else None

    # Ordered (filename, media type) candidates: PNG first, JPG for backward compatibility.
    candidates = (
        (f"{file_id}.png", "image/png"),
        (f"{file_id}.jpg", "image/jpeg"),
    )
    match = None
    for candidate_name, candidate_type in candidates:
        candidate_path = os.path.join(RESULTS_DIR, candidate_name)
        if os.path.exists(candidate_path):
            match = (candidate_path, candidate_type)
            break

    if match is None:
        log_api_call(
            endpoint=f"/download/{file_id}",
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=client_ip
        )
        raise HTTPException(status_code=404, detail="Result not found")

    result_path, result_media_type = match
    log_api_call(
        endpoint=f"/download/{file_id}",
        method="GET",
        status_code=200,
        request_data={"file_id": file_id},
        ip_address=client_ip
    )
    return FileResponse(result_path, media_type=result_media_type)
# -------------------------------------------------
# 🌐 Public Result File
# -------------------------------------------------
@app.get("/results/{filename}")
def get_result(request: Request, filename: str):
    """Publicly serve a colorized result file from ``RESULTS_DIR``.

    Args:
        request: Incoming request; only the client IP is used (for logging).
        filename: File name inside ``RESULTS_DIR``.

    Returns:
        FileResponse whose media type is inferred from the extension:
        ``.png`` -> PNG, ``.jpg``/``.jpeg`` -> JPEG, anything else -> PNG.

    Raises:
        HTTPException: 404 when no regular file with that name exists.
    """
    ip_address = request.client.host if request.client else None
    path = os.path.join(RESULTS_DIR, filename)
    # isfile rather than exists: a name such as ".." resolves to a directory,
    # which FileResponse refuses to send (it would surface as a 500, not a 404).
    if not os.path.isfile(path):
        log_api_call(
            endpoint=f"/results/{filename}",
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="Result not found")
    lowered = filename.lower()
    if lowered.endswith(".png"):
        media_type = "image/png"
    elif lowered.endswith((".jpg", ".jpeg")):
        media_type = "image/jpeg"
    else:
        media_type = "image/png"  # Default to PNG
    log_api_call(
        endpoint=f"/results/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )
    return FileResponse(path, media_type=media_type)
# -------------------------------------------------
# 🌐 Public Uploaded File
# -------------------------------------------------
@app.get("/uploads/{filename}")
def get_upload(request: Request, filename: str):
    """Publicly serve an uploaded source image from ``UPLOAD_DIR``.

    Args:
        request: Incoming request; only the client IP is used (for logging).
        filename: File name inside ``UPLOAD_DIR``.

    Returns:
        FileResponse with media type PNG for ``.png`` files and JPEG for
        everything else (including ``.jpg``/``.jpeg``).

    Raises:
        HTTPException: 404 when no regular file with that name exists.
    """
    ip_address = request.client.host if request.client else None
    path = os.path.join(UPLOAD_DIR, filename)
    # isfile rather than exists: a name such as ".." resolves to a directory,
    # which FileResponse refuses to send (it would surface as a 500, not a 404).
    if not os.path.isfile(path):
        log_api_call(
            endpoint=f"/uploads/{filename}",
            method="GET",
            status_code=404,
            error="File not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="File not found")
    # .jpg/.jpeg and unknown extensions all default to JPEG.
    media_type = "image/png" if filename.lower().endswith(".png") else "image/jpeg"
    log_api_call(
        endpoint=f"/uploads/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )
    return FileResponse(path, media_type=media_type)
# -------------------------------------------------
# 🌐 Public Compressed File
# -------------------------------------------------
@app.get("/compressed/{filename}")
def get_compressed(request: Request, filename: str):
    """Publicly serve a compressed result image from ``COMPRESSED_DIR``.

    Args:
        request: Incoming request; only the client IP is used (for logging).
        filename: File name inside ``COMPRESSED_DIR``.

    Returns:
        FileResponse, always served as ``image/jpeg``.

    Raises:
        HTTPException: 404 when no regular file with that name exists.
    """
    ip_address = request.client.host if request.client else None
    path = os.path.join(COMPRESSED_DIR, filename)
    # isfile rather than exists: a name such as ".." resolves to a directory,
    # which FileResponse refuses to send (it would surface as a 500, not a 404).
    if not os.path.isfile(path):
        log_api_call(
            endpoint=f"/compressed/{filename}",
            method="GET",
            status_code=404,
            error="Compressed file not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="Compressed file not found")
    # Compressed images are JPEG
    media_type = "image/jpeg"
    log_api_call(
        endpoint=f"/compressed/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )
    return FileResponse(path, media_type=media_type)
@app.get("/")
async def root():
    """Return a static banner identifying the API, its version and publisher."""
    product_info = {
        "version": "1.0.0",
        "Product Name": "Beauty Camera - GlowCam AI Studio",
        "Released By": "LogicGo Infotech",
    }
    return {
        "success": True,
        "message": "Image Colorization API",
        "data": product_info,
    }
#=================================================================================
#main_sdxt.py
#=================================================================================
"""
FastAPI application for Text-Guided Image Colorization using Hugging Face Inference API
Uses fal-ai provider for memory-efficient inference
"""
import os
import io
import uuid
import logging
from pathlib import Path
from typing import Optional, Tuple
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request, Body, Form
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import firebase_admin
from firebase_admin import credentials, app_check, auth as firebase_auth
from PIL import Image
import uvicorn
import gradio as gr
import httpx
from pydantic import BaseModel, EmailStr
# Hugging Face Inference API
from huggingface_hub import InferenceClient
from app.config import settings
from app.database import (
get_database,
log_api_call,
log_image_upload,
log_colorization,
log_media_click,
close_connection,
)
# Configure process-wide logging at INFO with timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Create writable scratch directories under /tmp: HF cache, matplotlib config,
# uploaded inputs and colorized results.
for _writable_dir in (
    "/tmp/hf_cache",
    "/tmp/matplotlib_config",
    "/tmp/colorize_uploads",
    "/tmp/colorize_results",
):
    Path(_writable_dir).mkdir(parents=True, exist_ok=True)
del _writable_dir
# Create the FastAPI application with its OpenAPI metadata.
app = FastAPI(
    version="1.0.0",
    title="Text-Guided Image Colorization API",
    description="Image colorization using SDXL + ControlNet with automatic captioning",
)
# CORS: any origin, method and header; credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Initialize Firebase Admin SDK. Credential sources are tried in order and the
# first success wins:
#   1. service-account JSON files (FIREBASE_CREDENTIALS_PATH, then fixed paths);
#   2. the FIREBASE_CREDENTIALS env var holding the JSON itself (HF Spaces);
#   3. firebase_admin.initialize_app() with no explicit credential.
firebase_cred_paths = [
    os.getenv("FIREBASE_CREDENTIALS_PATH"),
    "/tmp/firebase-adminsdk.json",
    "/data/firebase-adminsdk.json",
    "colorize-662df-firebase-adminsdk-fbsvc-bfd21c77c6.json",
    os.path.join(os.path.dirname(__file__), "..", "colorize-662df-firebase-adminsdk-fbsvc-bfd21c77c6.json"),
]
firebase_initialized = False
for cred_path in firebase_cred_paths:
    if not cred_path:
        continue  # FIREBASE_CREDENTIALS_PATH not set
    cred_path = os.path.abspath(cred_path)
    if not os.path.exists(cred_path):
        continue
    try:
        cred = credentials.Certificate(cred_path)
        firebase_admin.initialize_app(cred)
    except Exception as e:
        logger.warning("Failed to initialize Firebase from %s: %s", cred_path, str(e))
        continue
    logger.info("Firebase Admin SDK initialized from: %s", cred_path)
    firebase_initialized = True
    break

# Also try loading from environment variable (for Hugging Face Spaces)
if not firebase_initialized:
    firebase_json = os.getenv("FIREBASE_CREDENTIALS")
    if firebase_json:
        try:
            import json
            firebase_dict = json.loads(firebase_json)
            cred = credentials.Certificate(firebase_dict)
            firebase_admin.initialize_app(cred)
            logger.info("Firebase Admin SDK initialized from environment variable")
            firebase_initialized = True
        except Exception as e:
            logger.warning("Failed to initialize Firebase from environment: %s", str(e))

if not firebase_initialized:
    logger.warning("Firebase credentials file not found. App Check will be disabled.")
    try:
        firebase_admin.initialize_app()
    except Exception as e:
        # Best effort only: the auth helpers in this module treat an empty
        # firebase_admin._apps as "auth disabled", so startup continues. Catch
        # Exception (not a bare except) so SystemExit/KeyboardInterrupt still
        # propagate, and record why default initialization failed.
        logger.warning("Default Firebase initialization failed: %s", str(e))
# Storage directories (the same /tmp paths created with mkdir at import time).
UPLOAD_DIR = Path("/tmp/colorize_uploads")
RESULT_DIR = Path("/tmp/colorize_results")
# Mount static files.
# NOTE(review): StaticFiles checks by default that its directory exists, so the
# mkdir calls must run before these lines. Starlette matches routes in
# registration order, so these mounts take precedence over any later
# @app.get("/results/...") / "/uploads/..." handler on the same app. Confirm
# that is intended.
app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
# Global Inference API client; assigned by the startup hook, None until then.
inference_client = None
# Last client-initialization error message (None on success), reported by /health.
model_load_error: Optional[str] = None
# ========== Utility Functions ==========
def apply_color(image: Image.Image, color_map: Image.Image) -> Image.Image:
    """Return ``image`` recolored with the chroma of ``color_map``.

    Both images are converted to LAB. The result keeps the L (lightness)
    channel of ``image``, takes the a/b (color) channels from ``color_map``,
    and is converted back to RGB. ``Image.merge`` requires the bands to match
    in size, so both inputs are expected to have the same dimensions.
    """
    lightness = image.convert('LAB').split()[0]
    chroma_a, chroma_b = color_map.convert('LAB').split()[1:]
    return Image.merge('LAB', (lightness, chroma_a, chroma_b)).convert('RGB')
def remove_unlikely_words(prompt: str) -> str:
    """Strip phrases that steer a caption toward an old or monochrome look.

    Removes, by plain case-sensitive substring replacement:
      * 20th-century decade/year tokens: "1950s", "1950", "year 1950",
        "circa 1950", and their digit-spaced forms ("1 9 5 0 s", "1 9 5 0",
        "year 1 9 5 0", "circa 1 9 5 0");
      * a hand-written list of black-and-white / grainy / historical /
        camera phrases, plus a few stray comma/space patterns.

    Replacements run one after another in list order, and each sees the output
    of the previous ones. A longer entry listed before a shorter one it
    contains is removed whole (e.g. "1950s" before "1950").

    Args:
        prompt: Caption or prompt text to clean.

    Returns:
        The cleaned text. Whitespace is not otherwise normalized.
    """
    # NOTE(review): the whole list is rebuilt on every call; it could be
    # hoisted to a module-level constant.
    unlikely_words = []
    a1 = [f'{i}s' for i in range(1900, 2000)]   # "1900s" ... "1999s"
    a2 = [f'{i}' for i in range(1900, 2000)]    # "1900" ... "1999"
    a3 = [f'year {i}' for i in range(1900, 2000)]
    a4 = [f'circa {i}' for i in range(1900, 2000)]
    # Digit-spaced variants built from a1 ("1950s" -> "1 9 5 0 s", etc.).
    b1 = [f"{y[0]} {y[1]} {y[2]} {y[3]} s" for y in a1]
    b2 = [f"{y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
    b3 = [f"year {y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
    b4 = [f"circa {y[0]} {y[1]} {y[2]} {y[3]}" for y in a1]
    # NOTE(review): because replacement is substring-based and ordered, some
    # later, longer entries can no longer match (e.g. "grainy" is removed before
    # "grainy photograph,"), and short entries such as "bw", "historic" or
    # "setting" also hit inside longer words (e.g. "settings" -> "s").
    manual = [
        "black and white,", "black and white", "black & white,", "black & white", "circa",
        "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
        "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
        "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
        "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
        "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
        "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
        "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
        "black-and-white photo,", "black-and-white photo", "black - and - white photography",
        "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
        "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
        "black - and - white photograph,", "black - and - white photograph", "black on white,",
        "black on white", "black-and-white", "historical image,", "historical picture,",
        "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
        "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
        "historical photo", "historical setting,",
        "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
        "taken in", "shot on leica", "shot on leica sl2", "sl2",
        "taken with a leica camera", "leica sl2", "leica", "setting",
        "overcast day", "overcast weather", "slight overcast", "overcast",
        "picture taken in", "photo taken in",
        ", photo", ", photo", ", photo", ", photo", ", photograph",
        ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
    ]
    # Order matters: "YYYYs" (a1) is removed before "YYYY" (a2).
    unlikely_words.extend(a1 + a2 + a3 + a4 + b1 + b2 + b3 + b4 + manual)
    for word in unlikely_words:
        prompt = prompt.replace(word, "")
    return prompt
# ========== Model Loading ==========
@app.on_event("startup")
async def startup_event():
    """Connect MongoDB and build the Hugging Face Inference API client.

    Both steps log failures instead of raising, so the app still starts. A
    client error is kept in ``model_load_error`` for the health endpoint.
    """
    global inference_client, model_load_error

    # MongoDB is optional at startup: log and carry on if it fails.
    try:
        if get_database() is not None:
            logger.info("✅ MongoDB initialized successfully!")
    except Exception as exc:
        logger.warning("⚠️ MongoDB initialization failed: %s", str(exc))

    try:
        logger.info("🔄 Initializing Hugging Face Inference API client...")
        # The environment variable takes precedence over the settings value.
        token = os.getenv("HF_TOKEN") or settings.HF_TOKEN
        if not token:
            raise ValueError("HF_TOKEN environment variable is required for Inference API")
        inference_client = InferenceClient(provider="fal-ai", api_key=token)
        logger.info("✅ Inference API client initialized successfully!")
        model_load_error = None
    except Exception as exc:
        error_msg = str(exc)
        logger.error(f"❌ Failed to initialize Inference API client: {error_msg}")
        model_load_error = error_msg
@app.on_event("shutdown")
async def shutdown_event():
    """Release the inference client, close MongoDB and log the shutdown."""
    global inference_client
    if not inference_client:
        pass  # nothing to release
    else:
        inference_client = None
    close_connection()
    logger.info("Application shutdown")
# ========== Authentication Models ==========
class RegisterRequest(BaseModel):
    """Request body for POST /auth/register."""
    email: EmailStr  # validated e-mail address for the new Firebase user
    password: str  # plain-text password, passed to firebase_auth.create_user
    display_name: Optional[str] = None  # optional profile display name
class LoginRequest(BaseModel):
    """Request body for POST /auth/login (e-mail/password sign-in)."""
    email: EmailStr  # account e-mail
    password: str  # plain-text password, sent to Firebase signInWithPassword
class TokenResponse(BaseModel):
    """Response model for /auth/register and /auth/login."""
    # Firebase ID token; /auth/register puts a custom token here that the client
    # must exchange for an ID token.
    id_token: str
    refresh_token: Optional[str] = None  # only set by the REST-API login path
    expires_in: int  # token lifetime in seconds (3600 when Firebase omits it)
    token_type: str = "Bearer"
    user: dict  # keys: uid, email, display_name, email_verified
# ========== Authentication ==========
def _extract_bearer_token(authorization_header: str | None) -> str | None:
if not authorization_header:
return None
parts = authorization_header.split(" ", 1)
if len(parts) == 2 and parts[0].lower() == "bearer":
return parts[1].strip()
return None
async def verify_request(request: Request):
    """FastAPI dependency that authenticates the caller through Firebase.

    Credentials are checked in this order:
      1. ``X-Firebase-AppCheck`` header, via ``app_check.verify_token``.
      2. ``Authorization: Bearer <id_token>``, via
         ``firebase_auth.verify_id_token``. On success the decoded claims are
         stored on ``request.state.user``.

    Authentication is skipped (returns True) when no Firebase app is
    initialized or ``DISABLE_AUTH=true``. When ``settings.ENABLE_APP_CHECK`` is
    off, requests without valid credentials are also let through.

    Raises:
        HTTPException: 401 when App Check enforcement is on and either the App
            Check token is invalid or no credential could be verified.
    """
    if not firebase_admin._apps or os.getenv("DISABLE_AUTH", "false").lower() == "true":
        return True

    # 1) Firebase App Check token.
    appcheck_header = request.headers.get("X-Firebase-AppCheck")
    if appcheck_header:
        try:
            claims = app_check.verify_token(appcheck_header)
            logger.info("App Check token verified for: %s", claims.get("app_id"))
            return True
        except Exception as exc:
            logger.warning("App Check token verification failed: %s", str(exc))
            if settings.ENABLE_APP_CHECK:
                raise HTTPException(status_code=401, detail="Invalid App Check token")

    # 2) Firebase Auth ID token (used by the /auth/* e-mail/password flow).
    id_token = _extract_bearer_token(request.headers.get("Authorization"))
    if id_token:
        try:
            claims = firebase_auth.verify_id_token(id_token)
            request.state.user = claims
            logger.info("Firebase Auth id_token verified for uid: %s", claims.get("uid"))
            return True
        except Exception as exc:
            logger.warning("Auth token verification failed: %s", str(exc))

    if not settings.ENABLE_APP_CHECK:
        # App Check enforcement is off: allow unauthenticated access.
        return True
    # Enforcement is on and nothing verified. A present App Check header was
    # either accepted or rejected above, so here it must be missing.
    raise HTTPException(status_code=401, detail="Missing App Check token")
def _resolve_user_id(request: Request, supplied_user_id: Optional[str]) -> Optional[str]:
    """Normalize a client-supplied user id.

    Returns the id with surrounding whitespace removed, or None when it is
    missing or blank (the original note says log_media_click then
    auto-generates one). ``request`` is accepted for signature compatibility
    and is not used.
    """
    cleaned = (supplied_user_id or "").strip()
    return cleaned or None
# ========== Auth Endpoints ==========
@app.post("/auth/register", response_model=TokenResponse)
async def register_user(user_data: RegisterRequest):
    """Create a Firebase e-mail/password user and return a custom token.

    The ``id_token`` field of the response holds a Firebase *custom* token,
    which the client must exchange for a real ID token.

    Raises:
        HTTPException: 503 if Firebase is not initialized; 400 for an
            already-registered e-mail or invalid input; 500 otherwise.
    """
    if not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")
    try:
        new_user = firebase_auth.create_user(
            email=user_data.email,
            password=user_data.password,
            display_name=user_data.display_name,
            email_verified=False,
        )
        # Custom token the client exchanges for an ID token.
        token_bytes = firebase_auth.create_custom_token(new_user.uid)
        logger.info("User registered: %s (uid: %s)", user_data.email, new_user.uid)
        profile = {
            "uid": new_user.uid,
            "email": new_user.email,
            "display_name": new_user.display_name,
            "email_verified": new_user.email_verified,
        }
        return TokenResponse(
            id_token=token_bytes.decode('utf-8'),
            token_type="Bearer",
            expires_in=3600,
            user=profile,
        )
    except firebase_auth.EmailAlreadyExistsError:
        raise HTTPException(status_code=400, detail="Email already registered")
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid input: {str(exc)}")
    except Exception as exc:
        logger.error("Registration error: %s", str(exc))
        raise HTTPException(status_code=500, detail=f"Registration failed: {str(exc)}")
@app.post("/auth/login", response_model=TokenResponse)
async def login_user(credentials: LoginRequest):
    """Authenticate an e-mail/password user through the Firebase Auth REST API.

    The Admin SDK cannot check passwords, so this calls
    ``accounts:signInWithPassword`` and then loads the user's profile with the
    Admin SDK.

    Args:
        credentials: E-mail and password from the request body.

    Returns:
        TokenResponse with the Firebase ID token, refresh token and lifetime.

    Raises:
        HTTPException: 503 if Firebase or FIREBASE_API_KEY is not configured;
            401 when Firebase rejects the credentials; 500 when the auth
            service is unreachable or anything else fails.
    """
    if not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")
    # Firebase REST API endpoint for email/password authentication
    firebase_api_key = os.getenv("FIREBASE_API_KEY") or settings.FIREBASE_API_KEY
    if not firebase_api_key:
        # Without the Web API key the password cannot be verified. Never fall
        # back to minting a custom token for an existing e-mail: that would let
        # anyone sign in as any user without knowing the password.
        logger.error("Login rejected: FIREBASE_API_KEY is not configured")
        raise HTTPException(status_code=503, detail="Firebase API key not configured")
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={firebase_api_key}",
                json={
                    "email": credentials.email,
                    "password": credentials.password,
                    "returnSecureToken": True
                }
            )
            if response.status_code != 200:
                error_data = response.json()
                error_msg = error_data.get("error", {}).get("message", "Authentication failed")
                raise HTTPException(status_code=401, detail=error_msg)
            data = response.json()
            logger.info("User login successful: %s", credentials.email)
            # Get user details from Admin SDK
            user_record = firebase_auth.get_user(data["localId"])
            return TokenResponse(
                id_token=data["idToken"],
                refresh_token=data.get("refreshToken"),
                expires_in=int(data.get("expiresIn", 3600)),
                token_type="Bearer",
                user={
                    "uid": user_record.uid,
                    "email": user_record.email,
                    "display_name": user_record.display_name,
                    "email_verified": user_record.email_verified
                }
            )
    except HTTPException:
        # Propagate the deliberate 401 above; the generic handler below would
        # otherwise turn a wrong password into a 500.
        raise
    except httpx.HTTPError as e:
        logger.error("HTTP error during login: %s", str(e))
        raise HTTPException(status_code=500, detail="Authentication service unavailable")
    except Exception as e:
        logger.error("Login error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Login failed: {str(e)}")
@app.get("/auth/me")
async def get_current_user(request: Request, verified: bool = Depends(verify_request)):
    """Return the Firebase profile of the caller authenticated by ``verify_request``.

    Only a verified Bearer ID token populates ``request.state.user``. A request
    authorized only by App Check, or let through unauthenticated, gets 401.

    Raises:
        HTTPException: 503 if Firebase is not initialized; 401 without an
            authenticated user; 404 if the profile lookup fails.
    """
    if not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")
    has_user = hasattr(request, 'state') and hasattr(request.state, 'user')
    if not has_user:
        raise HTTPException(status_code=401, detail="Not authenticated")
    uid = request.state.user.get("uid")
    try:
        record = firebase_auth.get_user(uid)
        return {
            "uid": record.uid,
            "email": record.email,
            "display_name": record.display_name,
            "email_verified": record.email_verified,
            "created_at": record.user_metadata.creation_timestamp,
        }
    except Exception as exc:
        logger.error("Error getting user: %s", str(exc))
        raise HTTPException(status_code=404, detail="User not found")
@app.post("/auth/refresh")
async def refresh_token(refresh_token: str = Body(..., embed=True)):
    """Exchange a Firebase refresh token for a fresh ID token.

    Expects the body ``{"refresh_token": "..."}`` and calls the Google Secure
    Token API.

    Returns:
        dict with ``id_token``, ``refresh_token``, ``expires_in`` (seconds) and
        ``token_type``.

    Raises:
        HTTPException: 503 if FIREBASE_API_KEY is not configured; 401 when the
            refresh token is rejected; 500 when the service is unreachable or
            anything else fails.
    """
    firebase_api_key = os.getenv("FIREBASE_API_KEY") or settings.FIREBASE_API_KEY
    if not firebase_api_key:
        raise HTTPException(status_code=503, detail="Firebase API key not configured")
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://securetoken.googleapis.com/v1/token?key={firebase_api_key}",
                json={
                    "grant_type": "refresh_token",
                    "refresh_token": refresh_token
                }
            )
            if response.status_code != 200:
                error_data = response.json()
                error_msg = error_data.get("error", {}).get("message", "Token refresh failed")
                raise HTTPException(status_code=401, detail=error_msg)
            data = response.json()
            return {
                "id_token": data["id_token"],
                "refresh_token": data.get("refresh_token"),
                "expires_in": int(data.get("expires_in", 3600)),
                "token_type": "Bearer"
            }
    except HTTPException:
        # Propagate the deliberate 401 above; the generic handler below would
        # otherwise turn an invalid refresh token into a 500.
        raise
    except httpx.HTTPError as e:
        logger.error("HTTP error during token refresh: %s", str(e))
        raise HTTPException(status_code=500, detail="Token refresh service unavailable")
    except Exception as e:
        logger.error("Token refresh error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Token refresh failed: {str(e)}")
# ========== API Endpoints ==========
@app.get("/api")
async def api_info(request: Request):
    """Describe the service and its endpoint map. The call is logged."""
    auth_routes = {
        "register": "/auth/register",
        "login": "/auth/login",
        "me": "/auth/me",
        "refresh": "/auth/refresh",
    }
    endpoint_map = {
        "health": "/health",
        "upload": "/upload",
        "colorize": "/colorize",
        "download": "/download/{file_id}",
        "results": "/results/{filename}",
        "uploads": "/uploads/{filename}",
        "auth": auth_routes,
        "gradio": "/",
    }
    response_data = {
        "app": "Text-Guided Image Colorization API",
        "version": "1.0.0",
        "endpoints": endpoint_map,
    }
    # Attribute the log entry to a user when one was attached to the request.
    caller_uid = None
    if hasattr(request, 'state') and hasattr(request.state, 'user'):
        caller_uid = request.state.user.get("uid")
    caller_ip = request.client.host if request.client else None
    log_api_call(
        endpoint="/api",
        method="GET",
        status_code=200,
        response_data=response_data,
        user_id=caller_uid,
        ip_address=caller_ip,
    )
    return response_data
@app.get("/health")
async def health_check(request: Request):
    """Report liveness and Inference API client status. The call is logged."""
    status = {
        "status": "healthy",
        "model_loaded": inference_client is not None,
        "model_type": "hf_inference_api",
        "provider": "fal-ai",
    }
    # Surface the last client-initialization failure, if any.
    if model_load_error:
        status["model_error"] = model_load_error
    caller_ip = request.client.host if request.client else None
    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=status,
        ip_address=caller_ip,
    )
    return status
def colorize_image_sdxl(
    image: Image.Image,
    positive_prompt: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    seed: int = 123,
    num_inference_steps: int = 8
) -> Tuple[Image.Image, str]:
    """Colorize ``image`` via Hugging Face Inference API image-to-image.

    The image is converted to RGB and resized to 512x512 (the aspect ratio is
    not preserved). It is sent as PNG bytes, with a text prompt, to
    ``settings.INFERENCE_MODEL``, and the result is resized back to the
    original size.

    Args:
        image: Source PIL image.
        positive_prompt: Prompt text. A generic colorization prompt is used
            when it is empty or None.
        negative_prompt: If given, appended to the prompt as
            ``". Avoid: <negative_prompt>"`` rather than sent separately.
        seed: Accepted for API compatibility but NOT forwarded to the
            Inference API, so results are not seeded.
        num_inference_steps: Accepted for API compatibility but NOT forwarded.

    Returns:
        ``(colorized_image, caption)``, where caption is the first 100
        characters of the prompt that was sent.

    Raises:
        RuntimeError: if the client is not initialized or the API call fails.
    """
    if inference_client is None:
        raise RuntimeError("Inference API client not initialized")

    source_size = image.size
    # Resize to 512x512 for inference (FLUX models work well at this size)
    buffer = io.BytesIO()
    image.convert("RGB").resize((512, 512)).save(buffer, format="PNG")
    png_payload = buffer.getvalue()

    prompt_text = positive_prompt or "colorize this image with vibrant natural colors, high quality"
    # Note: Some models may not support negative_prompt directly
    final_prompt = f"{prompt_text}. Avoid: {negative_prompt}" if negative_prompt else prompt_text

    model_name = settings.INFERENCE_MODEL
    logger.info(f"Calling Inference API with model {model_name}, prompt: {final_prompt}")
    try:
        generated = inference_client.image_to_image(
            png_payload,
            prompt=final_prompt,
            model=model_name,
        )
        if not isinstance(generated, Image.Image):
            # If it's bytes, convert to PIL Image
            generated = Image.open(io.BytesIO(generated))
        colorized = generated.resize(source_size)
        caption = final_prompt[:100]  # Truncate for display
        return colorized, caption
    except Exception as e:
        logger.error(f"Inference API error: {e}")
        raise RuntimeError(f"Failed to colorize image: {str(e)}")
def _upload_extension(filename: Optional[str]) -> str:
    """Return a safe extension for an uploaded file name, defaulting to "jpg".

    Only the final path component is considered, and only a purely ASCII
    alphanumeric suffix after its last dot is accepted. This prevents names
    such as "photo" (no dot) or "a.b/c" from yielding odd or slash-containing
    stored file names.
    """
    base = (filename or "").replace("\\", "/").rsplit("/", 1)[-1]
    _, dot, suffix = base.rpartition(".")
    if dot and suffix.isascii() and suffix.isalnum():
        return suffix
    return "jpg"


@app.post("/upload")
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
    verified: bool = Depends(verify_request)
):
    """Store an uploaded image under ``UPLOAD_DIR`` and return its public URL.

    The request is guarded by ``verify_request``. The file is saved as
    ``<uuid4>.<ext>``, where ``ext`` comes from the client filename if it is
    alphanumeric and is "jpg" otherwise. ``category_id``/``categoryId`` are
    accepted for client compatibility but are not recorded by this endpoint.

    Returns:
        JSONResponse with ``image_id`` (uuid), ``image_url`` and ``filename``.

    Raises:
        HTTPException: 400 for a non-image content type, 500 on any failure
            while storing or logging the upload.
    """
    effective_user_id = _resolve_user_id(request, user_id)
    ip_address = request.client.host if request.client else None
    if not file.content_type or not file.content_type.startswith("image/"):
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=400,
            error="File must be an image",
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        # Generate unique filename; the uuid stem doubles as the public image_id.
        file_extension = _upload_extension(file.filename)
        image_stem = str(uuid.uuid4())
        image_id = f"{image_stem}.{file_extension}"
        file_path = UPLOAD_DIR / image_id
        # Save uploaded file
        img_bytes = await file.read()
        file_size = len(img_bytes)
        with open(file_path, "wb") as f:
            f.write(img_bytes)
        logger.info("Image uploaded: %s", image_id)
        # Get base URL from settings or environment
        base_url = os.getenv("BASE_URL", settings.BASE_URL)
        if not base_url or base_url == "http://localhost:8000":
            base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
        response_data = {
            "success": True,
            "image_id": image_stem,
            "image_url": f"{base_url}/uploads/{image_id}",
            "filename": image_id
        }
        # Log to MongoDB
        log_image_upload(
            image_id=image_stem,
            filename=file.filename or image_id,
            file_size=file_size,
            content_type=file.content_type or "image/jpeg",
            user_id=effective_user_id,
            ip_address=ip_address
        )
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type},
            response_data=response_data,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        return JSONResponse(response_data)
    except Exception as e:
        error_msg = str(e)
        logger.error("Error uploading image: %s", error_msg)
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=500, detail=f"Error uploading image: {error_msg}")
@app.post("/colorize")
async def colorize_api(
    request: Request,
    file: UploadFile = File(...),
    positive_prompt: Optional[str] = None,
    negative_prompt: Optional[str] = None,
    seed: int = 123,
    num_inference_steps: int = 8,
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
    verified: bool = Depends(verify_request)
):
    """Colorize an uploaded image and return links to the stored PNG result.

    The image comes from the multipart ``file`` field. ``positive_prompt``,
    ``negative_prompt``, ``seed`` and ``num_inference_steps`` are forwarded to
    ``colorize_image_sdxl``, which calls the Hugging Face Inference API. The
    request is guarded by ``verify_request``. ``category_id``/``categoryId``
    are accepted for client compatibility but are not recorded.

    Returns:
        JSONResponse with ``result_id``, ``download_url``, ``api_download_url``,
        ``filename`` and ``caption``.

    Raises:
        HTTPException: 503 when the Inference API client is not ready, 400 for
            a non-image content type, 500 for any processing failure.
    """
    import asyncio
    import time
    start_time = time.time()
    effective_user_id = _resolve_user_id(request, user_id)
    ip_address = request.client.host if request.client else None
    if inference_client is None:
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=503,
            error="Inference API client not initialized",
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=503, detail="Inference API client not initialized")
    if not file.content_type or not file.content_type.startswith("image/"):
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error="File must be an image",
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        img_bytes = await file.read()
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        logger.info("Colorizing image with SDXL + ControlNet...")
        # colorize_image_sdxl performs a blocking HTTP call to the Inference
        # API; run it in a worker thread so the event loop keeps serving other
        # requests (including /health) while inference is in progress.
        colorized, caption = await asyncio.to_thread(
            colorize_image_sdxl,
            image,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            seed=seed,
            num_inference_steps=num_inference_steps,
        )
        processing_time = time.time() - start_time
        output_filename = f"{uuid.uuid4()}.png"
        output_path = RESULT_DIR / output_filename
        colorized.save(output_path, "PNG")
        logger.info("Colorized image saved: %s", output_filename)
        # Get base URL from settings or environment
        base_url = os.getenv("BASE_URL", settings.BASE_URL)
        if not base_url or base_url == "http://localhost:8000":
            base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
        result_id = output_filename.replace(".png", "")
        response_data = {
            "success": True,
            "result_id": result_id,
            "download_url": f"{base_url}/results/{output_filename}",
            "api_download_url": f"{base_url}/download/{result_id}",
            "filename": output_filename,
            "caption": caption
        }
        # Log to MongoDB (colorization_db -> colorizations)
        log_colorization(
            result_id=result_id,
            prompt=positive_prompt,
            model_type="sdxl",
            processing_time=processing_time,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="success"
        )
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={
                "filename": file.filename,
                "positive_prompt": positive_prompt,
                "negative_prompt": negative_prompt,
                "seed": seed,
                "num_inference_steps": num_inference_steps
            },
            response_data=response_data,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        return JSONResponse(response_data)
    except Exception as e:
        error_msg = str(e)
        logger.error("Error colorizing image: %s", error_msg)
        # Log failed colorization to colorizations collection
        log_colorization(
            result_id=None,
            prompt=positive_prompt,
            model_type="sdxl",
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg
        )
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}")
@app.get("/download/{file_id}")
def download_result(
    request: Request,
    file_id: str,
    verified: bool = Depends(verify_request)
):
    """Return ``RESULT_DIR/<file_id>.png`` to a caller that passed ``verify_request``.

    Raises:
        HTTPException: 404 if the PNG does not exist.
    """
    has_user = hasattr(request, 'state') and hasattr(request.state, 'user')
    caller_uid = request.state.user.get("uid") if has_user else None
    caller_ip = request.client.host if request.client else None
    endpoint = f"/download/{file_id}"
    png_path = RESULT_DIR / f"{file_id}.png"

    if png_path.exists():
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=200,
            request_data={"file_id": file_id},
            user_id=caller_uid,
            ip_address=caller_ip
        )
        return FileResponse(png_path, media_type="image/png")

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=404,
        error="Result not found",
        user_id=caller_uid,
        ip_address=caller_ip
    )
    raise HTTPException(status_code=404, detail="Result not found")
@app.get("/results/{filename}")
def get_result(request: Request, filename: str):
    """Public endpoint that returns ``RESULT_DIR/<filename>`` as a PNG.

    NOTE(review): a StaticFiles app is mounted at "/results" on this same app
    before this route is registered. Starlette matches in registration order,
    so this handler appears unreachable. Confirm before relying on its logging.

    Raises:
        HTTPException: 404 if the file does not exist.
    """
    caller_ip = request.client.host if request.client else None
    endpoint = f"/results/{filename}"
    result_path = RESULT_DIR / filename

    if result_path.exists():
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=200,
            request_data={"filename": filename},
            ip_address=caller_ip
        )
        return FileResponse(result_path, media_type="image/png")

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=404,
        error="Result not found",
        ip_address=caller_ip
    )
    raise HTTPException(status_code=404, detail="Result not found")
# ========== Gradio Interface (Optional) ==========
def gradio_colorize(image, positive_prompt=None, negative_prompt=None, seed=123):
    """Gradio callback that colorizes ``image`` via ``colorize_image_sdxl``.

    Always returns an ``(image, caption)`` pair. An empty upload yields
    ``(None, "")``. Any failure (including a missing inference client) yields
    ``None`` in the image slot and a human-readable reason in the caption
    slot.
    """
    if image is None:
        return None, ""
    try:
        if inference_client is None:
            return None, "Inference API client not initialized"
        colorized_image, generated_caption = colorize_image_sdxl(
            image,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            seed=seed,
        )
    except Exception as exc:
        reason = str(exc)
        logger.error("Gradio colorization error: %s", reason)
        return None, reason
    return colorized_image, generated_caption
title = "🎨 Text-Guided Image Colorization"
description = "Upload a grayscale image and generate a color version using Hugging Face Inference API (fal-ai provider)."
# Gradio UI around ``gradio_colorize``. Gradio passes the inputs to the
# callback positionally, so their order must match its parameters
# (image, positive_prompt, negative_prompt, seed). The two outputs map to the
# returned (image, caption) pair.
iface = gr.Interface(
    fn=gradio_colorize,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Positive Prompt", placeholder="Enter details to enhance the caption"),
        # Pre-filled with the configured default negative prompt.
        gr.Textbox(label="Negative Prompt", value=settings.NEGATIVE_PROMPT),
        # Positional args: minimum 0, maximum 1000, default 123.
        gr.Slider(0, 1000, 123, label="Seed")
    ],
    outputs=[
        gr.Image(type="pil", label="Colorized Image"),
        gr.Textbox(label="Caption")
    ],
    title=title,
    description=description,
)
# Mount Gradio app at root
# NOTE: ``app`` is rebound to the return value of ``mount_gradio_app``; the UI
# now owns "/", so API routes must live on other paths.
app = gr.mount_gradio_app(app, iface, path="/")
# ========== Run Server ==========
if __name__ == "__main__":
    # Direct execution only: serve ``app`` on all interfaces, using the port
    # from $PORT (default 7860).
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)
#==========================================================================================
#main_fastapi.py
#==========================================================================================
"""
FastAPI application for FastAI GAN Image Colorization
with Firebase Authentication and Gradio UI
"""
import os
# Set environment variables BEFORE any imports
# Point every Hugging Face / XDG cache location at the writable /tmp/hf_cache
# and give matplotlib a writable config directory. This runs before the ML
# libraries below are imported. OMP_NUM_THREADS=1 limits OpenMP to one thread.
os.environ.update(
    OMP_NUM_THREADS="1",
    HF_HOME="/tmp/hf_cache",
    TRANSFORMERS_CACHE="/tmp/hf_cache",
    HF_HUB_CACHE="/tmp/hf_cache",
    HUGGINGFACE_HUB_CACHE="/tmp/hf_cache",
    XDG_CACHE_HOME="/tmp/hf_cache",
    MPLCONFIGDIR="/tmp/matplotlib_config",
)
import io
import uuid
import logging
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Request, Form
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
import firebase_admin
from firebase_admin import credentials, app_check, auth as firebase_auth
from PIL import Image
import torch
import uvicorn
import gradio as gr
import numpy as np
import cv2
# FastAI imports
from fastai.vision.all import *
from huggingface_hub import from_pretrained_fastai
from app.config import settings
from app.pytorch_colorizer import PyTorchColorizer
from app.database import (
get_database,
log_api_call,
log_image_upload,
log_colorization,
log_media_click,
close_connection,
)
# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Writable scratch directories: HF cache, matplotlib config, uploads, results.
for _scratch_dir in (
    "/tmp/hf_cache",
    "/tmp/matplotlib_config",
    "/tmp/colorize_uploads",
    "/tmp/colorize_results",
):
    Path(_scratch_dir).mkdir(parents=True, exist_ok=True)
del _scratch_dir
# Initialize FastAPI app
app = FastAPI(
    title="FastAI Image Colorizer API",
    description="Image colorization using FastAI GAN model with Firebase authentication",
    version="1.0.0"
)
# CORS middleware
# NOTE(review): a wildcard origin combined with allow_credentials=True accepts
# credentialed cross-origin requests from any site. Confirm this is intended,
# or replace "*" with an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize Firebase Admin SDK
def _initialize_default_firebase_app() -> None:
    """Best-effort creation of the default Firebase app without explicit credentials.

    Failure is logged rather than raised, so the service still starts. While no
    Firebase app is registered, ``verify_request`` lets every request through.
    """
    try:
        firebase_admin.initialize_app()
    except Exception as exc:  # deliberate best-effort; never abort startup here
        logger.warning("Default Firebase initialization failed: %s", exc)


# Initialize the Firebase Admin SDK from a service-account file when one is
# present ($FIREBASE_CREDENTIALS_PATH, default /tmp/firebase-adminsdk.json).
# Otherwise, or if that fails, fall back to a credential-less default app.
firebase_cred_path = os.getenv("FIREBASE_CREDENTIALS_PATH", "/tmp/firebase-adminsdk.json")
if os.path.exists(firebase_cred_path):
    try:
        cred = credentials.Certificate(firebase_cred_path)
        firebase_admin.initialize_app(cred)
        logger.info("Firebase Admin SDK initialized")
    except Exception as e:
        logger.warning("Failed to initialize Firebase: %s", str(e))
        _initialize_default_firebase_app()
else:
    logger.warning("Firebase credentials file not found. App Check will be disabled.")
    _initialize_default_firebase_app()
# Storage directories
UPLOAD_DIR = Path("/tmp/colorize_uploads")
RESULT_DIR = Path("/tmp/colorize_results")
# Mount static files
# NOTE(review): these mounts serve both directories over HTTP without the
# verify_request dependency that /colorize uses. Anyone who knows a file name
# can fetch it. colorize_api writes its PNG outputs into RESULT_DIR.
app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
app.mount("/uploads", StaticFiles(directory=str(UPLOAD_DIR)), name="uploads")
# Initialize FastAI model
# Model state, filled in by startup_event(). At most one of ``learn`` (FastAI)
# and ``pytorch_colorizer`` is set; ``model_type`` records which one.
learn = None
pytorch_colorizer = None
# Error text from the last failed model load attempt; reported by /health.
model_load_error: Optional[str] = None
model_type: str = "none" # "fastai", "pytorch", or "none"
@app.on_event("startup")
async def startup_event():
    """Initialize MongoDB and load a colorization model at application startup.

    Model loading is tried in this order:
      1. FastAI learner via ``from_pretrained_fastai(model_id)``.
      2. ``PyTorchColorizer`` using ``generator.pt`` from the same repo.

    ``model_id`` comes from ``$MODEL_ID`` (default
    ``Hammad712/GAN-Colorization-Model``). The function sets the module
    globals ``learn`` / ``pytorch_colorizer``, ``model_type`` and
    ``model_load_error``.

    It never raises. MongoDB and model failures are logged so the app and
    ``/health`` still come up, and colorization then uses the heuristic
    fallback. When both loaders fail, ``model_load_error`` records both
    failure reasons, not only the PyTorch one.
    """
    global learn, pytorch_colorizer, model_load_error, model_type

    # MongoDB failures are logged, not fatal.
    try:
        db = get_database()
        if db is not None:
            logger.info("✅ MongoDB initialized successfully!")
    except Exception as e:
        logger.warning("⚠️ MongoDB initialization failed: %s", str(e))

    model_id = os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model")

    # 1) FastAI learner.
    try:
        logger.info("🔄 Attempting to load FastAI GAN Colorization Model: %s", model_id)
        learn = from_pretrained_fastai(model_id)
    except Exception as e:
        fastai_error = str(e)
        logger.warning("⚠️ FastAI model loading failed: %s. Trying PyTorch fallback...", fastai_error)
    else:
        logger.info("✅ FastAI model loaded successfully!")
        model_type = "fastai"
        model_load_error = None
        return

    # 2) PyTorch generator fallback (reached only when FastAI failed).
    try:
        logger.info("🔄 Attempting to load PyTorch GAN Colorization Model: %s", model_id)
        pytorch_colorizer = PyTorchColorizer(model_id=model_id, model_filename="generator.pt")
    except Exception as e:
        # Keep both reasons: the FastAI error explains why it was skipped.
        combined_error = f"FastAI: {fastai_error}; PyTorch: {e}"
        logger.error("❌ Failed to load both FastAI and PyTorch models: %s", combined_error)
        model_load_error = combined_error
        model_type = "none"
        # Don't raise - allow health check to work
    else:
        logger.info("✅ PyTorch model loaded successfully!")
        model_type = "pytorch"
        model_load_error = None
@app.on_event("shutdown")
async def shutdown_event():
    """Drop model references and close the MongoDB connection on shutdown."""
    global learn, pytorch_colorizer
    # Rebind to None instead of ``del``. Deleting a ``global`` removes the
    # module attribute, so any later ``learn is not None`` check (health_check,
    # colorize_pil) would raise NameError instead of seeing "no model".
    # Unconditional rebinding also avoids relying on model objects' truthiness.
    learn = None
    pytorch_colorizer = None
    close_connection()
    logger.info("Application shutdown")
def _extract_bearer_token(authorization_header: str | None) -> str | None:
if not authorization_header:
return None
parts = authorization_header.split(" ", 1)
if len(parts) == 2 and parts[0].lower() == "bearer":
return parts[1].strip()
return None
async def verify_request(request: Request):
    """
    Verify Firebase authentication
    Accept either:
    - Firebase Auth id_token via Authorization: Bearer <id_token>
    - Firebase App Check token via X-Firebase-AppCheck (when ENABLE_APP_CHECK=true)

    Used as a FastAPI dependency; returns True when the request may proceed.
    When an id_token is verified, its decoded claims are stored on
    ``request.state.user``.

    Raises:
        HTTPException: 401 when ``settings.ENABLE_APP_CHECK`` is on, no valid
            id_token was presented, and the App Check token is missing or
            fails verification.

    NOTE(review): this fails open. With App Check disabled, requests with no
    token, and requests whose bearer token fails verification, are allowed.
    """
    # If Firebase is not initialized or auth is explicitly disabled, allow.
    # (``_apps`` is a private firebase_admin attribute, presumably the registry
    # of initialized apps; empty means no app was initialized.)
    if not firebase_admin._apps or os.getenv("DISABLE_AUTH", "false").lower() == "true":
        return True
    # Try Firebase Auth id_token first if present
    bearer = _extract_bearer_token(request.headers.get("Authorization"))
    if bearer:
        try:
            decoded = firebase_auth.verify_id_token(bearer)
            # Handlers read the caller's uid from request.state.user.
            request.state.user = decoded
            logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
            return True
        except Exception as e:
            # A bad id_token is not rejected here; evaluation continues with
            # the App Check rules below.
            logger.warning("Auth token verification failed: %s", str(e))
    # If App Check is enabled, require valid App Check token
    if settings.ENABLE_APP_CHECK:
        app_check_token = request.headers.get("X-Firebase-AppCheck")
        if not app_check_token:
            raise HTTPException(status_code=401, detail="Missing App Check token")
        try:
            app_check_claims = app_check.verify_token(app_check_token)
            logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
            return True
        except Exception as e:
            logger.warning("App Check token verification failed: %s", str(e))
            raise HTTPException(status_code=401, detail="Invalid App Check token")
    # Neither token required nor provided → allow (App Check disabled)
    # NOTE(review): this also admits requests whose bearer token failed
    # verification above. Confirm the fail-open policy is intended.
    return True
def _resolve_user_id(request: Request, supplied_user_id: Optional[str]) -> Optional[str]:
    """Normalize the caller-supplied ``user_id`` form value.

    Returns the whitespace-trimmed value, or ``None`` when it is missing or
    blank (log_media_click is expected to auto-generate an id in that case).
    ``request`` is accepted for signature compatibility and is not read.
    """
    cleaned = (supplied_user_id or "").strip()
    return cleaned or None
@app.get("/api")
async def api_info(request: Request):
    """Describe the service and its main routes; the call itself is logged."""
    response_data = dict(
        app="FastAI Image Colorizer API",
        version="1.0.0",
        health="/health",
        colorize="/colorize",
        gradio="/",
    )
    # request.state.user is only present when something upstream set it.
    caller_uid = (
        request.state.user.get("uid")
        if hasattr(request, "state") and hasattr(request.state, "user")
        else None
    )
    client_ip = request.client.host if request.client else None
    log_api_call(
        endpoint="/api",
        method="GET",
        status_code=200,
        response_data=response_data,
        user_id=caller_uid,
        ip_address=client_ip,
    )
    return response_data
@app.get("/health")
async def health_check(request: Request):
    """Report liveness plus which colorization backend, if any, is active.

    Always reports ``status: healthy``. ``using_fallback`` is true when
    neither model is loaded; ``message`` (and ``model_error`` after a failed
    load) explains the current state. The response is logged.
    """
    has_model = learn is not None or pytorch_colorizer is not None

    if model_load_error:
        status_details = {
            "model_error": model_load_error,
            "message": "Model failed to load. Using fallback colorization method.",
        }
    elif has_model:
        status_details = {"message": f"Model loaded successfully ({model_type})"}
    else:
        status_details = {"message": "No model loaded. Using fallback colorization method."}

    response = {
        "status": "healthy",
        "model_loaded": has_model,
        "model_type": model_type,
        "model_id": os.getenv("MODEL_ID", "Hammad712/GAN-Colorization-Model"),
        "using_fallback": not has_model,
        **status_details,
    }

    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=response,
        ip_address=request.client.host if request.client else None,
    )
    return response
def simple_colorize_fallback(image: Image.Image) -> Image.Image:
    """Heuristic colorization used when no trained model is loaded.

    Steps, all in OpenCV colour spaces:
      * The L (lightness) channel gets CLAHE contrast enhancement.
      * The a/b chroma channels receive a brightness-dependent warm shift
        (towards red on ``a`` and yellow on ``b``), strongest for bright pixels.
      * Saturation is boosted 1.2x in HSV.

    This is a simple heuristic and will not match a trained model.

    Args:
        image: Input image in any PIL mode; it is converted to RGB first.

    Returns:
        A new RGB PIL image of the same size.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    rgb = np.array(image)

    # OpenCV 8-bit LAB: every channel is 0..255, with a/b offset so that
    # 128 means neutral.
    lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)

    # Enhance lightness with CLAHE (Contrast Limited Adaptive Histogram Equalization).
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    l_enhanced = clahe.apply(l)

    # Weight in [0, 1]: 0 at or below 30% lightness, reaching 1 at 80%.
    # It is derived from the *original* lightness, not the CLAHE output.
    l_normalized = l.astype(np.float32) / 255.0
    brightness_mask = np.clip((l_normalized - 0.3) * 2, 0, 1)

    # Both terms are positive, so every pixel is pushed warmer
    # (a: +2..+8 towards red, b: +3..+12 towards yellow). Darker pixels just
    # receive a smaller shift.
    a_hint = np.clip(a.astype(np.float32) + brightness_mask * 8 + (1 - brightness_mask) * 2, 0, 255).astype(np.uint8)
    b_hint = np.clip(b.astype(np.float32) + brightness_mask * 12 + (1 - brightness_mask) * 3, 0, 255).astype(np.uint8)

    # Merge channels and convert back to RGB.
    lab_colored = cv2.merge([l_enhanced, a_hint, b_hint])
    colored_rgb = cv2.cvtColor(lab_colored, cv2.COLOR_LAB2RGB)

    # Slight saturation boost (S channel scaled by 1.2, clipped to 8 bits).
    hsv = cv2.cvtColor(colored_rgb, cv2.COLOR_RGB2HSV)
    hsv[:, :, 1] = np.clip(hsv[:, :, 1].astype(np.float32) * 1.2, 0, 255).astype(np.uint8)
    colored_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    return Image.fromarray(colored_rgb)
def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a ``(C, H, W)`` or batched ``(N, C, H, W)`` image tensor to PIL.

    Only the first item of a batch is used. Float tensors are assumed to lie
    in [0, 1]; other integer dtypes are clamped to 0..255. Everything is
    narrowed to uint8 before handing the buffer to PIL.

    Raises:
        ValueError: if the tensor (after dropping the batch axis) is not 3-D.
    """
    t = tensor.detach()  # .numpy() refuses tensors that track gradients
    if t.dim() == 4:
        t = t[0]
    if t.dim() != 3:
        raise ValueError(f"Unexpected tensor shape: {t.shape}")
    t = t.permute(1, 2, 0).cpu()  # CHW -> HWC, the layout PIL expects
    if t.is_floating_point():
        # Assumes a [0, 1] output range (unchanged from before); TODO confirm per model.
        t = (torch.clamp(t, 0, 1) * 255).to(torch.uint8)
    elif t.dtype != torch.uint8:
        # e.g. int64 from decoding. PIL would otherwise reinterpret the wider
        # raw buffer as 8-bit RGB and produce a garbled image.
        t = torch.clamp(t, 0, 255).to(torch.uint8)
    pixels = t.numpy()
    if pixels.shape[-1] == 1:
        # Single-channel output becomes greyscale; the caller converts to RGB.
        return Image.fromarray(pixels[..., 0])
    return Image.fromarray(pixels)


def _colorize_with_fastai(image: Image.Image) -> Image.Image:
    """Run the global FastAI learner on ``image`` and return an RGB PIL image.

    Raises:
        ValueError: if the prediction is neither a PIL image nor a usable tensor.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    pred = learn.predict(image)
    # Learner.predict returns a tuple whose first item is the decoded output.
    # Fall back to the input if an empty sequence ever comes back.
    if isinstance(pred, (list, tuple)):
        colorized = pred[0] if len(pred) > 0 else image
    else:
        colorized = pred
    if isinstance(colorized, torch.Tensor):
        colorized = _tensor_to_pil(colorized)
    elif not isinstance(colorized, Image.Image):
        raise ValueError(f"Unexpected prediction type: {type(colorized)}")
    return colorized if colorized.mode == "RGB" else colorized.convert("RGB")


def colorize_pil(image: Image.Image) -> Image.Image:
    """Run model prediction and return the colorized image.

    Backend preference: FastAI learner, then the PyTorch colorizer, then the
    heuristic ``simple_colorize_fallback``.

    Raises:
        ValueError: if the FastAI prediction has an unsupported type or shape.
    """
    if learn is not None:
        return _colorize_with_fastai(image)
    if pytorch_colorizer is not None:
        return pytorch_colorizer.colorize(image)
    # Final fallback: simple colorization
    logger.info("No model loaded, using enhanced colorization fallback (LAB color space method)")
    return simple_colorize_fallback(image)
@app.post("/colorize")
async def colorize_api(
    request: Request,
    file: UploadFile = File(...),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
    verified: bool = Depends(verify_request)
):
    """
    Upload a black & white image -> returns colorized image.

    Access goes through the ``verify_request`` dependency, which is bypassed
    entirely when DISABLE_AUTH=true. The image is colorized by whatever
    backend ``colorize_pil`` selects (model or heuristic fallback), saved as a
    PNG under RESULT_DIR, and streamed back. Every outcome is recorded via
    ``log_api_call``; colorization attempts are also recorded via
    ``log_colorization``.

    Args:
        file: Uploaded image; its declared content type must start with ``image/``.
        user_id: Optional caller id; blank values are treated as absent.
        category_id / categoryId: Optional category id (snake_case or camelCase field).

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded as one; 500 if colorization or saving fails.
    """
    import time

    start_time = time.perf_counter()  # monotonic clock for durations
    effective_user_id = _resolve_user_id(request, user_id)
    # Snake_case field wins; blank values collapse to None.
    # NOTE(review): not used below yet; presumably intended for log_media_click.
    raw_category_id = category_id or categoryId
    effective_category_id = raw_category_id.strip() if isinstance(raw_category_id, str) else raw_category_id
    effective_category_id = effective_category_id or None
    ip_address = request.client.host if request.client else None
    # Allow fallback colorization even if model isn't loaded
    if not file.content_type or not file.content_type.startswith("image/"):
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error="File must be an image",
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        img_bytes = await file.read()
        try:
            image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as exc:
            # Undecodable (UnidentifiedImageError is an OSError), truncated or
            # oversized uploads are client errors, not server failures.
            log_api_call(
                endpoint="/colorize",
                method="POST",
                status_code=400,
                error=f"Invalid image file: {exc}",
                user_id=effective_user_id,
                ip_address=ip_address
            )
            raise HTTPException(status_code=400, detail="Invalid or unreadable image file") from exc
        logger.info("Colorizing image...")
        colorized = colorize_pil(image)
        processing_time = time.perf_counter() - start_time
        output_filename = f"{uuid.uuid4()}.png"
        output_path = RESULT_DIR / output_filename
        colorized.save(output_path, "PNG")
        logger.info("Colorized image saved: %s", output_filename)
        result_id = output_filename.replace(".png", "")
        # Log to MongoDB (colorization_db -> colorizations)
        log_colorization(
            result_id=result_id,
            model_type=model_type,
            processing_time=processing_time,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="success"
        )
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type},
            response_data={"result_id": result_id, "filename": output_filename},
            user_id=effective_user_id,
            ip_address=ip_address
        )
        # Return the image file
        return FileResponse(
            output_path,
            media_type="image/png",
            filename=f"colorized_{output_filename}"
        )
    except HTTPException:
        # Already logged and mapped to the right status above.
        raise
    except Exception as e:
        error_msg = str(e)
        logger.exception("Error colorizing image: %s", error_msg)
        # Log failed colorization to colorizations collection
        log_colorization(
            result_id=None,
            model_type=model_type,
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg
        )
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}")
# ==========================================================
# Gradio Interface (for Space UI)
# ==========================================================
def gradio_colorize(image):
    """Gradio callback: return the colorized image, or ``None`` on failure.

    An empty upload (``None``) short-circuits to ``None``. Any exception from
    ``colorize_pil`` is logged and also mapped to ``None``, so the UI shows
    an empty output instead of an error.
    """
    result = None
    if image is not None:
        try:
            result = colorize_pil(image)
        except Exception as exc:
            logger.error("Gradio colorization error: %s", str(exc))
    return result
title = "🎨 Image Colorizer"
description = "Upload a black & white photo to generate a colorized version. Uses AI model when available, otherwise uses enhanced colorization fallback."
# Single-image Gradio UI around ``gradio_colorize``: one PIL image in, one out.
iface = gr.Interface(
    fn=gradio_colorize,
    inputs=gr.Image(type="pil", label="Upload B&W Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
    title=title,
    description=description,
)
# Mount Gradio app at root (this will be the Space UI)
# Note: This will override the root endpoint, so use /api for API info
# (``app`` is rebound to the value returned by ``mount_gradio_app``.)
app = gr.mount_gradio_app(app, iface, path="/")
# ==========================================================
# Run Server
# ==========================================================
if __name__ == "__main__":
    # Direct execution only: serve ``app`` (API + mounted Gradio UI) on all
    # interfaces, using the port from $PORT (default 7860).
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)