# LogicGoInfotechSpaces's picture
# Update app/main.py
# 002eda7 verified
from fastapi import FastAPI, File, UploadFile, HTTPException, Header, Request, Form, Depends, Body
from typing import Optional, Tuple
from fastapi.responses import FileResponse
from huggingface_hub import hf_hub_download
import uuid
import os
import io
import json
import logging
import httpx
from PIL import Image
import torch
from torchvision import transforms
import numpy as np
from pydantic import BaseModel, EmailStr
from app.database import (
get_database,
log_api_call,
log_image_upload,
log_colorization,
log_media_click,
close_connection,
)
try:
import firebase_admin
from firebase_admin import auth as firebase_auth, app_check, credentials
except ImportError:
firebase_admin = None
firebase_auth = None
app_check = None
credentials = None
# Import CCO colorizers
try:
from app.colorizers import eccv16, siggraph17
from app.colorizers.util import preprocess_img, postprocess_tens
CCO_AVAILABLE = True
except ImportError as e:
print(f"⚠️ CCO colorizers not available: {e}")
CCO_AVAILABLE = False
# Module-level logger used by the helpers and endpoints below.
logger = logging.getLogger(__name__)
# -------------------------------------------------
# 🚀 FastAPI App
# -------------------------------------------------
# Single application instance; every route and startup/shutdown hook in this
# module is registered on it.
app = FastAPI(title="Text-Guided Image Colorization API")
# CORS middleware
from fastapi.middleware.cors import CORSMiddleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; Starlette may reflect the caller's Origin for credentialed
# requests in this configuration. Confirm this is intended before shipping.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -------------------------------------------------
# 🔐 Firebase Initialization (ENV-based)
# -------------------------------------------------
# Initialise the default Firebase app only when the Admin SDK imported
# successfully and FIREBASE_CREDENTIALS is set. Every failure is printed and
# swallowed so the API still starts without Firebase. Code that needs Firebase
# checks ``firebase_admin._apps`` before using it.
if firebase_admin:
    try:
        # FIREBASE_CREDENTIALS carries the credential JSON as a string; it is
        # parsed and passed to credentials.Certificate.
        firebase_json = os.getenv("FIREBASE_CREDENTIALS")
        if firebase_json:
            print("🔥 Loading Firebase credentials from ENV...")
            firebase_dict = json.loads(firebase_json)
            cred = credentials.Certificate(firebase_dict)
            firebase_admin.initialize_app(cred)
        else:
            print("⚠️ No Firebase credentials found. Firebase disabled.")
    except Exception as e:
        # Covers malformed JSON, invalid credentials and initialisation errors.
        print("❌ Firebase initialization failed:", e)
else:
    print("⚠️ Firebase Admin SDK not available. Firebase features disabled.")
# -------------------------------------------------
# 📁 Directories (FIXED FOR HUGGINGFACE SPACES)
# -------------------------------------------------
# Scratch storage lives under /tmp so it is writable on Hugging Face Spaces.
UPLOAD_DIR = "/tmp/uploads"
RESULTS_DIR = "/tmp/results"
COMPRESSED_DIR = "/tmp/compressed"
for _storage_dir in (UPLOAD_DIR, RESULTS_DIR, COMPRESSED_DIR):
    os.makedirs(_storage_dir, exist_ok=True)
del _storage_dir
# Category id passed to log_media_click as its default; overridable through
# the DEFAULT_CATEGORY_FALLBACK environment variable.
MEDIA_CLICK_DEFAULT_CATEGORY = os.getenv("DEFAULT_CATEGORY_FALLBACK", "69368fcd2e46bd68ae1889b2")
# -------------------------------------------------
# 🧠 Load GAN Colorization Model
# -------------------------------------------------
# Hugging Face Hub location of the GAN generator checkpoint.
MODEL_REPO = "Hammad712/GAN-Colorization-Model"
MODEL_FILENAME = "generator.pt"
print("⬇️ Downloading GAN model...")
# Runs at import time and is not guarded: if the download fails, the module
# (and therefore the app) fails to import.
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
print("📦 Loading GAN model weights...")
# SECURITY NOTE(review): torch.load unpickles the checkpoint, so a tampered file
# in the remote repo could execute code here. Consider weights_only=True (only
# valid if the file is a plain state dict — TODO confirm) and/or pinning a repo
# revision.
# NOTE(review): state_dict is not referenced anywhere else in this file;
# colorize_image_gan is still a placeholder that does not use it.
state_dict = torch.load(model_path, map_location="cpu")
# NOTE: Replace with real model architecture
# from model import ColorizeNet
# model = ColorizeNet()
# model.load_state_dict(state_dict)
# model.eval()
# -------------------------------------------------
# 🧠 Load CCO Colorization Models
# -------------------------------------------------
# Pretrained CCO colorizers in eval mode, keyed by the names colorize_image_cco
# accepts. Stays empty when the colorizers package could not be imported.
cco_models = {}
if CCO_AVAILABLE:
    print("📦 Loading CCO models...")
    try:
        for _cco_name, _cco_factory in (("eccv16", eccv16), ("siggraph17", siggraph17)):
            cco_models[_cco_name] = _cco_factory(pretrained=True).eval()
        print("✅ CCO models loaded successfully!")
    except Exception as e:
        # Any loading failure disables the CCO path; requests then fall back to GAN.
        print(f"⚠️ Failed to load CCO models: {e}")
        CCO_AVAILABLE = False
def colorize_image_gan(img: Image.Image):
    """Placeholder GAN colorizer: return a grayscale copy of *img* in RGB mode.

    The real generator is not wired up yet, so this converts the input to
    luminance and replicates it across three channels. Replace it with the
    real model's prediction.

    Args:
        img: Input PIL image of any mode.

    Returns:
        An RGB PIL image of the same size whose three channels are identical.
    """
    to_tensor = transforms.ToTensor()
    # (1, H, W) luminance in [0, 1]; add a batch dim -> (1, 1, H, W).
    tensor = to_tensor(img.convert("L")).unsqueeze(0)
    # Replicate the single channel so ToPILImage yields an RGB image.
    tensor = tensor.repeat(1, 3, 1, 1)
    # Drop only the batch dimension. A bare squeeze() would also collapse H or W
    # for 1-pixel-tall/wide images and return a wrongly shaped (or failing) result.
    return transforms.ToPILImage()(tensor.squeeze(0))
def colorize_image_cco(img: Image.Image, model_name: str = "eccv16"):
    """Colorize *img* with a pretrained CCO network.

    Args:
        img: Input PIL image; converted to RGB when it is not already.
        model_name: "eccv16" or "siggraph17". Any other value silently falls
            back to "eccv16".

    Returns:
        The colorized image as an RGB PIL image.

    Raises:
        ValueError: If CCO support is unavailable or the selected model is not
            present in ``cco_models``.
    """
    if not CCO_AVAILABLE:
        raise ValueError("CCO models are not available")
    selected = model_name if model_name in ["eccv16", "siggraph17"] else "eccv16"
    network = cco_models.get(selected)
    if network is None:
        raise ValueError(f"CCO model '{selected}' not loaded")

    rgb_image = img if img.mode == "RGB" else img.convert("RGB")
    pixels = np.asarray(rgb_image)
    if pixels.ndim == 2:
        # Defensive: widen a single-channel array to three channels.
        pixels = np.tile(pixels[:, :, None], 3)

    # Presumably (full-resolution L tensor, resized L tensor) judging by how they
    # are used below — TODO confirm against app.colorizers.util.
    l_full, l_resized = preprocess_img(pixels)
    with torch.no_grad():
        predicted_ab = network(l_resized)
    # postprocess_tens yields RGB floats; clamp to [0, 1] before scaling to bytes.
    rgb_float = np.clip(postprocess_tens(l_full, predicted_ab), 0, 1)
    return Image.fromarray((rgb_float * 255).astype(np.uint8), 'RGB')
def colorize_image(img: Image.Image, model_type: str = "gan", cco_model: str = "eccv16"):
    """Dispatch *img* to the requested colorizer.

    Args:
        img: PIL image to colorize.
        model_type: "cco" selects the CCO networks; any other value uses the
            GAN placeholder.
        cco_model: "eccv16" or "siggraph17"; only consulted when model_type is "cco".

    Returns:
        The colorized PIL image.
    """
    if model_type != "cco":
        return colorize_image_gan(img)
    return colorize_image_cco(img, cco_model)
def compress_image_to_target_size(
    img: Image.Image,
    target_size_mb: float = 2.5,
    min_size_mb: float = 2.0,
    max_size_mb: float = 3.0,
    max_attempts: int = 20,
) -> Tuple[Image.Image, str]:
    """Re-encode *img* as JPEG, tuning quality and scale toward a size window.

    Each attempt encodes in memory and measures the byte size.

    * Too large: lower the JPEG quality by 5 (down to 30), then shrink the image
      in 5% steps (down to 50%), resetting quality to 75 after each shrink.
    * Too small: raise quality by 5 while under 1 MB and below quality 90, then
      undo any downscaling, then fine-tune quality by 2 up to 95.

    The search stops when the size is inside [min_size_mb, max_size_mb], when no
    further adjustment is possible, or after *max_attempts* encodes. The last
    encoding is kept.

    Args:
        img: PIL image to compress (converted to RGB if necessary).
        target_size_mb: Nominal target in MB. Kept for API compatibility; the
            search only uses the min/max window.
        min_size_mb: Lower bound of the acceptable size window, in MB.
        max_size_mb: Upper bound of the acceptable size window, in MB.
        max_attempts: Maximum number of encode attempts (at least one is made).

    Returns:
        Tuple of (the image that was encoded, path to a temporary ``.jpg`` file
        holding exactly those bytes). The file always exists on return; the
        caller owns it and should move or delete it.
    """
    import tempfile

    if img.mode != "RGB":
        img = img.convert("RGB")
    original_width, original_height = img.size

    quality = 85
    scale_factor = 1.0
    candidate = img
    encoded = b""

    for attempt in range(max(1, max_attempts)):
        if scale_factor < 1.0:
            # Clamp to 1px so tiny inputs never produce a zero-sized dimension.
            new_size = (
                max(1, int(original_width * scale_factor)),
                max(1, int(original_height * scale_factor)),
            )
            candidate = img.resize(new_size, Image.Resampling.LANCZOS)
        else:
            candidate = img

        # Measure in memory; only the final result is written to disk.
        buffer = io.BytesIO()
        candidate.save(buffer, "JPEG", quality=quality, optimize=True)
        encoded = buffer.getvalue()
        file_size_mb = len(encoded) / (1024 * 1024)
        logger.info("Compression attempt %d: quality=%d, scale=%.2f, size=%.2f MB",
                    attempt + 1, quality, scale_factor, file_size_mb)

        if min_size_mb <= file_size_mb <= max_size_mb:
            logger.info("Target size achieved: %.2f MB", file_size_mb)
            break

        if file_size_mb > max_size_mb:
            if quality > 30:
                quality -= 5
            elif scale_factor > 0.5:
                scale_factor -= 0.05
                quality = 75  # restart from a higher quality at the smaller size
            else:
                logger.warning("Reached minimum compression, file size: %.2f MB", file_size_mb)
                break
        else:
            # Too small: every branch changes a parameter or stops, so the loop
            # never repeats an identical encode.
            if file_size_mb < 1.0 and quality < 90:
                quality += 5
            elif scale_factor < 1.0:
                scale_factor = min(1.0, scale_factor + 0.1)
            elif quality < 95:
                quality = min(95, quality + 2)
            else:
                logger.info("File size %.2f MB is below minimum but at max quality, accepting", file_size_mb)
                break

    # Persist exactly the bytes that were measured, so the caller gets the tuned
    # encoding rather than having to re-encode with a guessed quality.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
        tmp_file.write(encoded)
        temp_path = tmp_file.name
    return candidate, temp_path
# -------------------------------------------------
# 🗄️ MongoDB Initialization
# -------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Call get_database() at startup so MongoDB problems surface early.

    Failures are printed, not raised, so the app still starts without a database.
    """
    try:
        if get_database() is not None:
            print("✅ MongoDB initialized successfully!")
    except Exception as e:
        print(f"⚠️ MongoDB initialization failed: {e}")
@app.on_event("shutdown")
async def shutdown_event():
    """On application stop, call close_connection() and print a shutdown notice."""
    close_connection()
    print("Application shutdown")
# -------------------------------------------------
# 🩺 Health Check
# -------------------------------------------------
@app.get("/health")
def health_check(request: Request):
    """Liveness probe: return a fixed status payload and record the call.

    NOTE(review): every field is a hard-coded literal. "model_type" and
    "provider" do not reflect the locally loaded GAN/CCO models.
    """
    client_ip = request.client.host if request.client else None
    payload = {
        "status": "healthy",
        "model_loaded": True,
        "model_type": "hf_inference_api",
        "provider": "fal-ai",
    }
    log_api_call(
        endpoint="/health",
        method="GET",
        status_code=200,
        response_data=payload,
        ip_address=client_ip,
    )
    return payload
# -------------------------------------------------
# 🔐 Auth Models
# -------------------------------------------------
class RegisterRequest(BaseModel):
    """Request body for POST /auth/register."""

    email: EmailStr  # validated by pydantic's EmailStr type
    password: str  # passed through to firebase_auth.create_user
    display_name: Optional[str] = None  # optional profile name
class LoginRequest(BaseModel):
    """Request body for POST /auth/login."""

    email: EmailStr  # validated by pydantic's EmailStr type
    password: str  # plain-text password for the Firebase sign-in call
class TokenResponse(BaseModel):
    """Response model for /auth/register and /auth/login."""

    # Firebase ID token on the REST sign-in path. /auth/register puts a Firebase
    # custom token here instead, which clients must exchange before use.
    id_token: str
    refresh_token: Optional[str] = None  # only set by the REST sign-in path
    expires_in: int  # token lifetime in seconds
    token_type: str = "Bearer"
    user: dict  # uid, email, display_name, email_verified
# -------------------------------------------------
# 🔐 Firebase Token Validator & Auth Helpers
# -------------------------------------------------
def _extract_bearer_token(authorization_header: Optional[str]) -> Optional[str]:
if not authorization_header:
return None
parts = authorization_header.split(" ", 1)
if len(parts) == 2 and parts[0].lower() == "bearer":
return parts[1].strip()
return None
async def verify_request(request: Request):
    """Best-effort Firebase verification used as a FastAPI dependency.

    Each credential is checked independently when present:

    1. Firebase App Check token (``X-Firebase-AppCheck`` header): the outcome
       is logged.
    2. Firebase Auth ID token (``Authorization: Bearer ...``): on success the
       decoded claims are stored on ``request.state.user`` for handlers such
       as ``/auth/me``.

    The dependency never rejects a request. It always returns True so public
    endpoints keep working; handlers that need a user must check
    ``request.state.user`` themselves. When the Firebase SDK is missing or no
    Firebase app is initialised, verification is skipped entirely.
    """
    if not firebase_admin:
        return True
    if not getattr(firebase_admin, "_apps", None):
        return True

    # Starlette header lookups are case-insensitive, so one lookup is enough.
    app_check_token = request.headers.get("X-Firebase-AppCheck")
    if app_check_token and app_check:
        try:
            app_check_claims = app_check.verify_token(app_check_token)
            logger.info("App Check token verified for: %s", app_check_claims.get("app_id"))
        except Exception as e:
            logger.warning("App Check token verification failed: %s", str(e))

    # Verify the ID token even when App Check succeeded. Returning early there
    # would leave request.state.user unset, so /auth/me would answer 401 for
    # clients that send both headers.
    bearer = _extract_bearer_token(request.headers.get("Authorization"))
    if bearer and firebase_auth:
        try:
            decoded = firebase_auth.verify_id_token(bearer)
            request.state.user = decoded
            logger.info("Firebase Auth id_token verified for uid: %s", decoded.get("uid"))
        except Exception as e:
            logger.warning("Auth token verification failed: %s", str(e))

    # No valid token is still allowed (public endpoints).
    return True
def verify_app_check_token(
    token: str | None,
    *,
    required: bool = False
):
    """Validate a Firebase App Check token sent by the client.

    If required=False:
      - Missing token is allowed
      - Invalid token is rejected
    If required=True:
      - Missing OR invalid token is rejected

    When App Check cannot be used (SDK not installed or no Firebase app
    initialised), a supplied token cannot be checked. It is allowed when
    ``required`` is False, consistent with ``verify_request`` and with how a
    missing token is treated. It is rejected with 503 when ``required`` is True.

    Returns:
        True when the request may proceed.

    Raises:
        HTTPException: 401 for a missing (required) or invalid token; 503 when
            verification is required but Firebase is unavailable.
    """
    if not token:
        if required:
            raise HTTPException(
                status_code=401,
                detail="Firebase App Check token missing"
            )
        return True  # OPTIONAL → allow request

    # Uses the guarded import at the top of the module: app_check/firebase_admin
    # are None when the SDK is not installed.
    firebase_ready = bool(app_check) and bool(firebase_admin) and bool(getattr(firebase_admin, "_apps", None))
    if not firebase_ready:
        if required:
            raise HTTPException(
                status_code=503,
                detail="Firebase App Check not configured"
            )
        logger.warning("App Check token supplied but Firebase is not initialised; skipping verification")
        return True

    # Token present and Firebase ready → must be valid
    try:
        app_check.verify_token(token)
    except Exception as e:
        logger.warning("App Check token verification failed: %s", str(e))
        raise HTTPException(
            status_code=401,
            detail="Invalid Firebase App Check token"
        ) from e
    return True
def _resolve_user_id(request: Request, supplied_user_id: Optional[str]) -> Optional[str]:
    """Normalise a client-supplied user id.

    Returns the id with surrounding whitespace removed, or None when it is
    missing or blank. Presumably log_media_click then generates an id — verify
    in app.database. ``request`` is unused and kept for call-site compatibility.
    """
    cleaned = (supplied_user_id or "").strip()
    return cleaned or None
# -------------------------------------------------
# 🔐 Auth Endpoints
# -------------------------------------------------
@app.post("/auth/register", response_model=TokenResponse)
async def register_user(user_data: RegisterRequest):
    """Register a new Firebase user with email and password.

    Note: the ``id_token`` field of the response carries a Firebase *custom
    token*, not an ID token. Clients must exchange it (e.g. via
    signInWithCustomToken) before sending it as a Bearer token.

    Raises:
        HTTPException: 503 if Firebase is unavailable; 400 if the email is
            already registered or Firebase rejects the supplied values (e.g. a
            too-short password); 500 for any other failure.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")
    if not hasattr(firebase_admin, '_apps') or not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")
    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")
    try:
        user_record = firebase_auth.create_user(
            email=user_data.email,
            password=user_data.password,
            display_name=user_data.display_name,
            email_verified=False
        )
        # Custom token that the client exchanges for an ID token.
        custom_token = firebase_auth.create_custom_token(user_record.uid)
    except firebase_auth.EmailAlreadyExistsError as e:
        raise HTTPException(status_code=400, detail="Email already registered") from e
    except ValueError as e:
        # firebase_admin validates arguments locally (e.g. password length) and
        # raises ValueError. That is a client error, not "email already registered".
        raise HTTPException(status_code=400, detail=f"Registration failed: {e}") from e
    except Exception as e:
        error_msg = str(e)
        if "already exists" in error_msg.lower():
            raise HTTPException(status_code=400, detail="Email already registered") from e
        logger.error("Registration error: %s", error_msg)
        raise HTTPException(status_code=500, detail=f"Registration failed: {error_msg}") from e

    logger.info("User registered: %s (uid: %s)", user_data.email, user_record.uid)
    return TokenResponse(
        id_token=custom_token.decode('utf-8') if isinstance(custom_token, bytes) else custom_token,
        token_type="Bearer",
        expires_in=3600,
        user={
            "uid": user_record.uid,
            "email": user_record.email,
            "display_name": user_record.display_name or "",
            "email_verified": user_record.email_verified
        }
    )
@app.post("/auth/login", response_model=TokenResponse)
async def login_user(credentials: LoginRequest):
    """Log in with email and password via the Firebase Auth REST API.

    The password is verified by Firebase's ``accounts:signInWithPassword``
    endpoint, which needs ``FIREBASE_API_KEY`` (env var or app settings).
    Without an API key the endpoint answers 503. The Admin SDK cannot verify
    passwords, so issuing a token from the email alone would let anyone log in
    as any existing user.

    Raises:
        HTTPException: 503 when Firebase or the API key is unavailable; 401 for
            rejected credentials; 500 for transport or unexpected errors.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")
    if not hasattr(firebase_admin, '_apps') or not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")
    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")
    from app.config import settings
    firebase_api_key = os.getenv("FIREBASE_API_KEY") or getattr(settings, 'FIREBASE_API_KEY', '')
    if not firebase_api_key:
        # Refuse rather than fall back to a password-less "user exists" check.
        logger.error("Login rejected: FIREBASE_API_KEY is not configured")
        raise HTTPException(status_code=503, detail="Firebase API key not configured")
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key={firebase_api_key}",
                json={
                    "email": credentials.email,
                    "password": credentials.password,
                    "returnSecureToken": True
                }
            )
        if response.status_code != 200:
            # Error bodies are normally {"error": {"message": ...}}, but tolerate
            # non-JSON or differently shaped bodies (e.g. proxy errors).
            try:
                error_msg = response.json()["error"]["message"]
            except (ValueError, KeyError, TypeError):
                error_msg = "Authentication failed"
            raise HTTPException(status_code=401, detail=error_msg)
        data = response.json()
        logger.info("User login successful: %s", credentials.email)
        # Get user details from Admin SDK
        user_record = firebase_auth.get_user(data["localId"])
        return TokenResponse(
            id_token=data["idToken"],
            refresh_token=data.get("refreshToken"),
            expires_in=int(data.get("expiresIn", 3600)),
            token_type="Bearer",
            user={
                "uid": user_record.uid,
                "email": user_record.email,
                "display_name": user_record.display_name or "",
                "email_verified": user_record.email_verified
            }
        )
    except httpx.HTTPError as e:
        logger.error("HTTP error during login: %s", str(e))
        raise HTTPException(status_code=500, detail="Authentication service unavailable") from e
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Login error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Login failed: {str(e)}") from e
@app.get("/auth/me")
async def get_current_user(request: Request, verified: bool = Depends(verify_request)):
    """Return the Firebase profile of the caller identified by verify_request.

    verify_request stores decoded ID-token claims on ``request.state.user``.
    Without them the caller is unauthenticated (401). Any error while looking
    up the uid in Firebase is reported as 404.
    """
    if not firebase_admin:
        raise HTTPException(status_code=503, detail="Firebase Admin SDK not available")
    if not hasattr(firebase_admin, '_apps') or not firebase_admin._apps:
        raise HTTPException(status_code=503, detail="Firebase not initialized")
    if not firebase_auth:
        raise HTTPException(status_code=503, detail="Firebase Auth not available")

    authenticated = hasattr(request, 'state') and hasattr(request.state, 'user')
    if not authenticated:
        raise HTTPException(status_code=401, detail="Not authenticated")

    uid = request.state.user.get("uid")
    try:
        profile = firebase_auth.get_user(uid)
        return {
            "uid": profile.uid,
            "email": profile.email,
            "display_name": profile.display_name or "",
            "email_verified": profile.email_verified,
        }
    except Exception as e:
        logger.error("Error getting user: %s", str(e))
        raise HTTPException(status_code=404, detail="User not found")
@app.post("/auth/refresh")
async def refresh_token(refresh_token_param: str = Body(..., embed=True)):
    """Exchange a Firebase refresh token for a fresh ID token.

    Expects a JSON body ``{"refresh_token_param": "<token>"}`` (embedded body
    parameter). Calls the Secure Token REST endpoint with the form-encoded
    payload that Firebase documents for this exchange.

    Raises:
        HTTPException: 503 if no API key is configured; 401 if Firebase rejects
            the refresh token; 500 for transport or unexpected errors.
    """
    from app.config import settings
    firebase_api_key = os.getenv("FIREBASE_API_KEY") or getattr(settings, 'FIREBASE_API_KEY', '')
    if not firebase_api_key:
        raise HTTPException(status_code=503, detail="Firebase API key not configured")
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"https://securetoken.googleapis.com/v1/token?key={firebase_api_key}",
                data={
                    "grant_type": "refresh_token",
                    "refresh_token": refresh_token_param
                }
            )
        if response.status_code != 200:
            # Tolerate non-JSON or unexpectedly shaped error bodies.
            try:
                error_msg = response.json()["error"]["message"]
            except (ValueError, KeyError, TypeError):
                error_msg = "Token refresh failed"
            raise HTTPException(status_code=401, detail=error_msg)
        data = response.json()
        return {
            "id_token": data["id_token"],
            "refresh_token": data.get("refresh_token"),
            "expires_in": int(data.get("expires_in", 3600)),
            "token_type": "Bearer"
        }
    except httpx.HTTPError as e:
        logger.error("HTTP error during token refresh: %s", str(e))
        raise HTTPException(status_code=500, detail="Token refresh service unavailable") from e
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Token refresh error: %s", str(e))
        raise HTTPException(status_code=500, detail=f"Token refresh failed: {str(e)}") from e
# -------------------------------------------------
# 📤 Upload Image
# -------------------------------------------------
@app.post("/upload")
async def upload_image(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
):
    """Store an uploaded image under UPLOAD_DIR and return its public URL.

    The category may be sent as ``category_id`` or the camelCase alias
    ``categoryId``; the first non-blank value wins. The file is stored as
    ``<uuid>.jpg`` whatever the uploaded format. The upload is recorded via
    log_image_upload, log_api_call and log_media_click.

    Raises:
        HTTPException: 401 for an invalid App Check token; 400 when the upload
            is not declared as an image (including a missing content type).
    """
    verify_app_check_token(x_firebase_appcheck)
    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)
    effective_category_id = next(
        (value.strip() for value in (category_id, categoryId)
         if isinstance(value, str) and value.strip()),
        None,
    )

    # content_type can be None when the client omits it; treat that as invalid
    # instead of crashing with AttributeError.
    if not (file.content_type or "").startswith("image/"):
        log_api_call(
            endpoint="/upload",
            method="POST",
            status_code=400,
            error="Invalid file type",
            ip_address=ip_address
        )
        raise HTTPException(status_code=400, detail="Invalid file type")

    image_id = f"{uuid.uuid4()}.jpg"
    file_path = os.path.join(UPLOAD_DIR, image_id)
    img_bytes = await file.read()
    file_size = len(img_bytes)
    with open(file_path, "wb") as f:
        f.write(img_bytes)

    base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
    response_data = {
        "success": True,
        "image_id": image_id.replace(".jpg", ""),
        "file_url": f"{base_url}/uploads/{image_id}"
    }
    # Log to MongoDB
    log_image_upload(
        image_id=image_id.replace(".jpg", ""),
        filename=file.filename or image_id,
        file_size=file_size,
        content_type=file.content_type or "image/jpeg",
        user_id=effective_user_id,
        ip_address=ip_address
    )
    log_api_call(
        endpoint="/upload",
        method="POST",
        status_code=200,
        request_data={"filename": file.filename, "content_type": file.content_type},
        response_data=response_data,
        user_id=effective_user_id,
        ip_address=ip_address
    )
    log_media_click(
        user_id=effective_user_id,
        category_id=effective_category_id,
        endpoint_path=str(request.url.path),
        default_category_id=MEDIA_CLICK_DEFAULT_CATEGORY,
    )
    return response_data
# -------------------------------------------------
# 🎨 Colorize Image
# -------------------------------------------------
@app.post("/colorize")
async def colorize(
    request: Request,
    file: UploadFile = File(...),
    x_firebase_appcheck: str = Header(None),
    user_id: Optional[str] = Form(None),
    category_id: Optional[str] = Form(None),
    categoryId: Optional[str] = Form(None),
    model: Optional[str] = Form(None),  # Model parameter: "gan", "cco", "cco-eccv16", "cco-siggraph17" (default: CCO if available)
):
    """Colorize an uploaded image and publish full-size and compressed results.

    Model selection (``model`` form field, trimmed and case-insensitive):
      * ``"gan"``: placeholder GAN path (returns a grayscale image).
      * ``"cco"`` / ``"cco-eccv16"``: CCO eccv16. ``"cco-siggraph17"``: CCO
        siggraph17. Other ``"cco-*"`` values use eccv16. Requesting CCO when it
        is not loaded is a 400.
      * missing/unknown: CCO eccv16 when available, otherwise GAN.

    The result is saved as PNG under RESULTS_DIR. A JPEG copy tuned toward
    2–3 MB is saved under COMPRESSED_DIR. Outcomes are logged via
    log_colorization and log_api_call; successes also call log_media_click.

    Raises:
        HTTPException: 401 for an invalid App Check token; 400 for an
            unavailable model or a non-image upload; 500 on processing failure.
    """
    import shutil
    import time

    start_time = time.time()
    verify_app_check_token(x_firebase_appcheck, required=False)
    ip_address = request.client.host if request.client else None
    effective_user_id = _resolve_user_id(request, user_id)
    # First non-blank of category_id / categoryId (camelCase alias), stripped.
    effective_category_id = next(
        (value.strip() for value in (category_id, categoryId)
         if isinstance(value, str) and value.strip()),
        None,
    )

    # Default: CCO eccv16 when loaded, otherwise the GAN placeholder.
    # cco_model must be bound on every path (colorize_image always receives it).
    # Leaving it unset on GAN-only deployments raised UnboundLocalError.
    cco_model = "eccv16"
    if CCO_AVAILABLE:
        model_type = "cco"
        model_type_for_log = "cco-eccv16"
    else:
        model_type = "gan"
        model_type_for_log = "gan"

    if model:
        model = model.strip().lower()
        if model == "gan":
            # Use GAN model (dummy implementation - doesn't actually colorize)
            model_type = "gan"
            model_type_for_log = "gan"
        elif model == "cco" or model.startswith("cco-"):
            if not CCO_AVAILABLE:
                error_msg = "CCO models are not available"
                log_api_call(
                    endpoint="/colorize",
                    method="POST",
                    status_code=400,
                    error=error_msg,
                    ip_address=ip_address
                )
                log_colorization(
                    result_id=None,
                    model_type="cco",
                    processing_time=None,
                    user_id=effective_user_id,
                    ip_address=ip_address,
                    status="failed",
                    error=error_msg
                )
                raise HTTPException(status_code=400, detail=error_msg)
            model_type = "cco"
            if model == "cco-siggraph17":
                cco_model = "siggraph17"
                model_type_for_log = "cco-siggraph17"
            else:
                # "cco", "cco-eccv16" and any other "cco-*" value use eccv16.
                cco_model = "eccv16"
                model_type_for_log = "cco-eccv16"
        # Any other value keeps the defaults chosen above.

    # content_type may be None when the client omits it; treat that as invalid.
    if not (file.content_type or "").startswith("image/"):
        error_msg = "Invalid file type"
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=400,
            error=error_msg,
            ip_address=ip_address
        )
        # Log failed colorization
        log_colorization(
            result_id=None,
            model_type=model_type_for_log,
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg
        )
        raise HTTPException(status_code=400, detail=error_msg)

    try:
        img = Image.open(io.BytesIO(await file.read()))
        if img.mode != "RGB":
            img = img.convert("RGB")
        output_img = colorize_image(img, model_type=model_type, cco_model=cco_model)
        if output_img.mode != "RGB":
            output_img = output_img.convert("RGB")
        processing_time = time.time() - start_time

        result_id = f"{uuid.uuid4()}.png"
        output_path = os.path.join(RESULTS_DIR, result_id)
        # Full-quality result as produced by the model.
        output_img.save(output_path, "PNG")

        logger.info("Creating compressed version targeting 2-3MB file size...")
        compressed_img, temp_compressed_path = compress_image_to_target_size(
            output_img,
            target_size_mb=2.5,
            min_size_mb=2.0,
            max_size_mb=3.0
        )
        compressed_filename = result_id.replace(".png", "_compressed.jpg")
        compressed_path = os.path.join(COMPRESSED_DIR, compressed_filename)
        # Move the tuned temp file into place; re-encode only if it is gone.
        if os.path.exists(temp_compressed_path):
            shutil.move(temp_compressed_path, compressed_path)
        else:
            compressed_img.save(compressed_path, "JPEG", quality=75, optimize=True)

        compressed_size_mb = os.path.getsize(compressed_path) / (1024 * 1024)
        original_size_mb = os.path.getsize(output_path) / (1024 * 1024)
        logger.info("Original image size: %.2f MB, Compressed image size: %.2f MB",
                    original_size_mb, compressed_size_mb)

        base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
        result_id_clean = result_id.replace(".png", "")
        # Use a consistent caption text for all models
        caption = "colorize this image with vibrant natural colors, high quality"
        response_data = {
            "success": True,
            "result_id": result_id_clean,
            "download_url": f"{base_url}/results/{result_id}",
            "api_download_url": f"{base_url}/download/{result_id_clean}",
            "filename": result_id,
            "caption": caption,
            "Compressed_Image_URL": f"{base_url}/compressed/{compressed_filename}"
        }
        # Log to MongoDB (colorization_db -> colorizations)
        log_colorization(
            result_id=result_id_clean,
            model_type=model_type_for_log,
            processing_time=processing_time,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="success"
        )
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=200,
            request_data={"filename": file.filename, "content_type": file.content_type, "model": model},
            response_data=response_data,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        log_media_click(
            user_id=effective_user_id,
            category_id=effective_category_id,
            endpoint_path=str(request.url.path),
            default_category_id=MEDIA_CLICK_DEFAULT_CATEGORY,
        )
        return response_data
    except Exception as e:
        error_msg = str(e)
        logger.error("Error colorizing image: %s", error_msg)
        # Log failed colorization to colorizations collection
        log_colorization(
            result_id=None,
            model_type=model_type_for_log,
            processing_time=None,
            user_id=effective_user_id,
            ip_address=ip_address,
            status="failed",
            error=error_msg
        )
        log_api_call(
            endpoint="/colorize",
            method="POST",
            status_code=500,
            error=error_msg,
            user_id=effective_user_id,
            ip_address=ip_address
        )
        raise HTTPException(status_code=500, detail=f"Error colorizing image: {error_msg}")
# -------------------------------------------------
# ⬇️ Download via API (Secure)
# -------------------------------------------------
@app.get("/download/{file_id}")
def download_result(
    request: Request,
    file_id: str,
    x_firebase_appcheck: str = Header(None)
):
    """Serve a colorized result by id, preferring ``<id>.png`` over legacy ``<id>.jpg``.

    Raises:
        HTTPException: 401 for an invalid App Check token; 404 if neither file exists.
    """
    verify_app_check_token(x_firebase_appcheck)
    client_ip = request.client.host if request.client else None
    endpoint = f"/download/{file_id}"

    path = media_type = None
    for candidate_name, candidate_type in ((f"{file_id}.png", "image/png"),
                                           (f"{file_id}.jpg", "image/jpeg")):
        candidate_path = os.path.join(RESULTS_DIR, candidate_name)
        if os.path.exists(candidate_path):
            path, media_type = candidate_path, candidate_type
            break

    if path is None:
        log_api_call(
            endpoint=endpoint,
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=client_ip
        )
        raise HTTPException(status_code=404, detail="Result not found")

    log_api_call(
        endpoint=endpoint,
        method="GET",
        status_code=200,
        request_data={"file_id": file_id},
        ip_address=client_ip
    )
    return FileResponse(path, media_type=media_type)
# -------------------------------------------------
# 🌐 Public Result File
# -------------------------------------------------
@app.get("/results/{filename}")
def get_result(request: Request, filename: str):
    """Publicly serve a colorized result file from RESULTS_DIR.

    Only regular files are served. Names that resolve to a directory (e.g.
    ``..`` or ``.``) are reported as 404 instead of failing later inside
    FileResponse. The media type is JPEG for .jpg/.jpeg names and PNG otherwise.

    Raises:
        HTTPException: 404 when no such file exists.
    """
    ip_address = request.client.host if request.client else None
    path = os.path.join(RESULTS_DIR, filename)
    if not os.path.isfile(path):
        log_api_call(
            endpoint=f"/results/{filename}",
            method="GET",
            status_code=404,
            error="Result not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="Result not found")
    if filename.lower().endswith(('.jpg', '.jpeg')):
        media_type = "image/jpeg"
    else:
        media_type = "image/png"  # .png results and the default
    log_api_call(
        endpoint=f"/results/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )
    return FileResponse(path, media_type=media_type)
# -------------------------------------------------
# 🌐 Public Uploaded File
# -------------------------------------------------
@app.get("/uploads/{filename}")
def get_upload(request: Request, filename: str):
    """Publicly serve an uploaded file from UPLOAD_DIR.

    Only regular files are served. Names that resolve to a directory (e.g.
    ``..`` or ``.``) are reported as 404 instead of failing later inside
    FileResponse. The media type is PNG for .png names and JPEG otherwise.

    Raises:
        HTTPException: 404 when no such file exists.
    """
    ip_address = request.client.host if request.client else None
    path = os.path.join(UPLOAD_DIR, filename)
    if not os.path.isfile(path):
        log_api_call(
            endpoint=f"/uploads/{filename}",
            method="GET",
            status_code=404,
            error="File not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="File not found")
    # Uploads are stored as .jpg, so JPEG is the default.
    media_type = "image/png" if filename.lower().endswith('.png') else "image/jpeg"
    log_api_call(
        endpoint=f"/uploads/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )
    return FileResponse(path, media_type=media_type)
# -------------------------------------------------
# 🌐 Public Compressed File
# -------------------------------------------------
@app.get("/compressed/{filename}")
def get_compressed(request: Request, filename: str):
    """Publicly serve a compressed JPEG result from COMPRESSED_DIR.

    Only regular files are served. Names that resolve to a directory (e.g.
    ``..`` or ``.``) are reported as 404 instead of failing later inside
    FileResponse.

    Raises:
        HTTPException: 404 when no such file exists.
    """
    ip_address = request.client.host if request.client else None
    path = os.path.join(COMPRESSED_DIR, filename)
    if not os.path.isfile(path):
        log_api_call(
            endpoint=f"/compressed/{filename}",
            method="GET",
            status_code=404,
            error="Compressed file not found",
            ip_address=ip_address
        )
        raise HTTPException(status_code=404, detail="Compressed file not found")
    # Compressed images are always written as JPEG.
    media_type = "image/jpeg"
    log_api_call(
        endpoint=f"/compressed/{filename}",
        method="GET",
        status_code=200,
        request_data={"filename": filename},
        ip_address=ip_address
    )
    return FileResponse(path, media_type=media_type)
@app.get("/")
async def root():
    """Describe the service: product name, publisher and API version."""
    product_info = {
        "version": "1.0.0",
        "Product Name": "Beauty Camera - GlowCam AI Studio",
        "Released By": "LogicGo Infotech",
    }
    return {
        "success": True,
        "message": "Image Colorization API",
        "data": product_info,
    }