diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,1994 +1,1994 @@ -import os -import io -import json -import traceback -from datetime import datetime,timedelta -from typing import Optional -import time -import uuid -import boto3 -from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends -from fastapi.responses import StreamingResponse, JSONResponse -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel -from pymongo import MongoClient -import gridfs -from gridfs.errors import NoFile -from bson.objectid import ObjectId -from PIL import Image -from fastapi.concurrency import run_in_threadpool -import shutil -import firebase_admin -from firebase_admin import credentials, auth -from PIL import Image -from huggingface_hub import InferenceClient -# --------------------------------------------------------------------- -# Load Firebase Config from env (stringified JSON) -# --------------------------------------------------------------------- -firebase_config_json = os.getenv("firebase_config") -if not firebase_config_json: - raise RuntimeError("❌ Missing Firebase config in environment variable 'firebase_config'") +# import os +# import io +# import json +# import traceback +# from datetime import datetime,timedelta +# from typing import Optional +# import time +# import uuid +# import boto3 +# from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends +# from fastapi.responses import StreamingResponse, JSONResponse +# from fastapi.middleware.cors import CORSMiddleware +# from pydantic import BaseModel +# from pymongo import MongoClient +# import gridfs +# from gridfs.errors import NoFile +# from bson.objectid import ObjectId +# from PIL import Image +# from fastapi.concurrency import run_in_threadpool +# import shutil +# import firebase_admin +# from firebase_admin import credentials, auth +# from PIL import Image +# from huggingface_hub import InferenceClient +# # --------------------------------------------------------------------- +# # Load Firebase Config from env (stringified JSON) +# # --------------------------------------------------------------------- +# firebase_config_json = os.getenv("firebase_config") +# if not firebase_config_json: +# raise RuntimeError("❌ Missing Firebase config in environment variable 'firebase_config'") -try: - firebase_creds_dict = json.loads(firebase_config_json) - cred = credentials.Certificate(firebase_creds_dict) - firebase_admin.initialize_app(cred) -except Exception as e: - raise RuntimeError(f"Failed to initialize Firebase Admin SDK: {e}") +# try: +# firebase_creds_dict = json.loads(firebase_config_json) +# cred = credentials.Certificate(firebase_creds_dict) +# firebase_admin.initialize_app(cred) +# except Exception as e: +# raise RuntimeError(f"Failed to initialize Firebase Admin SDK: {e}") -# --------------------------------------------------------------------- -# Hugging Face setup -# --------------------------------------------------------------------- -HF_TOKEN = os.getenv("HF_TOKEN") -if not HF_TOKEN: - raise RuntimeError("HF_TOKEN not set in environment variables") +# # --------------------------------------------------------------------- +# # Hugging Face setup +# # --------------------------------------------------------------------- +# HF_TOKEN = os.getenv("HF_TOKEN") +# if not HF_TOKEN: +# raise RuntimeError("HF_TOKEN not set in environment variables") -hf_client = InferenceClient(token=HF_TOKEN) -# --------------------------------------------------------------------- -# MODEL SELECTION -# --------------------------------------------------------------------- -genai = None # ✅ IMPORTANT: module-level declaration -MODEL = os.getenv("IMAGE_MODEL", "GEMINI").upper() -GEMINI_FORCE_CATEGORY_ID = "69368e741224bcb6bdb98076" -# --------------------------------------------------------------------- -# Gemini setup (ONLY used if MODEL == "GEMINI") -# --------------------------------------------------------------------- -GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") -GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-2.5-flash-image") +# hf_client = InferenceClient(token=HF_TOKEN) +# # --------------------------------------------------------------------- +# # MODEL SELECTION +# # --------------------------------------------------------------------- +# genai = None # ✅ IMPORTANT: module-level declaration +# MODEL = os.getenv("IMAGE_MODEL", "GEMINI").upper() +# GEMINI_FORCE_CATEGORY_ID = "69368e741224bcb6bdb98076" +# # --------------------------------------------------------------------- +# # Gemini setup (ONLY used if MODEL == "GEMINI") +# # --------------------------------------------------------------------- +# GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") +# GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-2.5-flash-image") + +# # --------------------------------------------------------------------- +# # MongoDB setup +# # --------------------------------------------------------------------- +# # MONGODB_URI=os.getenv("MONGODB_URI") +# # DB_NAME = "polaroid_db" + +# # mongo = MongoClient(MONGODB_URI) +# # db = mongo[DB_NAME] +# # fs = gridfs.GridFS(db) +# # logs_collection = db["logs"] +# # ---------------- Logs MongoDB (NEW) ---------------- +# MONGODB_URI = os.getenv("MONGODB_URI") +# POLAROID_LOGS_MONGO_URL = os.getenv("POLAROID_LOGS_MONGO_URL") -# --------------------------------------------------------------------- -# MongoDB setup -# --------------------------------------------------------------------- -# MONGODB_URI=os.getenv("MONGODB_URI") # DB_NAME = "polaroid_db" +# # Main DB (images, gridfs) # mongo = MongoClient(MONGODB_URI) # db = mongo[DB_NAME] # fs = gridfs.GridFS(db) -# logs_collection = db["logs"] -# ---------------- Logs MongoDB (NEW) ---------------- -MONGODB_URI = os.getenv("MONGODB_URI") -POLAROID_LOGS_MONGO_URL = os.getenv("POLAROID_LOGS_MONGO_URL") -DB_NAME = "polaroid_db" +# # Separate Logs DB +# logs_client = None +# logs_collection = None -# Main DB (images, gridfs) -mongo = MongoClient(MONGODB_URI) -db = mongo[DB_NAME] -fs = gridfs.GridFS(db) +# if POLAROID_LOGS_MONGO_URL: +# logs_client = MongoClient(POLAROID_LOGS_MONGO_URL) +# logs_db = logs_client["logs"] # FORCE DB NAME +# logs_collection = logs_db["polaroid"] # FORCE COLLECTION +# else: +# raise RuntimeError("POLAROID_LOGS_MONGO_URL not set") +# # --------------------------------------------------------------------- +# # DigitalOcean Spaces setup +# # --------------------------------------------------------------------- +# DO_SPACES_KEY = os.getenv("DO_SPACES_KEY") +# DO_SPACES_SECRET = os.getenv("DO_SPACES_SECRET") +# DO_SPACES_REGION = os.getenv("DO_SPACES_REGION", "blr1") +# DO_SPACES_BUCKET = os.getenv("DO_SPACES_BUCKET") +# DO_SPACES_ENDPOINT = os.getenv("DO_SPACES_ENDPOINT", f"https://{DO_SPACES_REGION}.digitaloceanspaces.com") -# Separate Logs DB -logs_client = None -logs_collection = None +# if not DO_SPACES_KEY or not DO_SPACES_SECRET: +# raise RuntimeError("Missing DigitalOcean Spaces credentials in environment variables") -if POLAROID_LOGS_MONGO_URL: - logs_client = MongoClient(POLAROID_LOGS_MONGO_URL) - logs_db = logs_client["logs"] # FORCE DB NAME - logs_collection = logs_db["polaroid"] # FORCE COLLECTION -else: - raise RuntimeError("POLAROID_LOGS_MONGO_URL not set") -# --------------------------------------------------------------------- -# DigitalOcean Spaces setup -# --------------------------------------------------------------------- -DO_SPACES_KEY = os.getenv("DO_SPACES_KEY") -DO_SPACES_SECRET = os.getenv("DO_SPACES_SECRET") -DO_SPACES_REGION = os.getenv("DO_SPACES_REGION", "blr1") -DO_SPACES_BUCKET = os.getenv("DO_SPACES_BUCKET") -DO_SPACES_ENDPOINT = os.getenv("DO_SPACES_ENDPOINT", f"https://{DO_SPACES_REGION}.digitaloceanspaces.com") +# # Initialize S3 client for DigitalOcean Spaces +# s3_client = boto3.client( +# 's3', +# region_name=DO_SPACES_REGION, +# endpoint_url=DO_SPACES_ENDPOINT, +# aws_access_key_id=DO_SPACES_KEY, +# aws_secret_access_key=DO_SPACES_SECRET +# ) -if not DO_SPACES_KEY or not DO_SPACES_SECRET: - raise RuntimeError("Missing DigitalOcean Spaces credentials in environment variables") +# # --------------------------------------------------------------------- +# # FastAPI app setup +# # --------------------------------------------------------------------- +# app = FastAPI(title="Qwen Image Edit API with Firebase Auth") +# app.add_middleware( +# CORSMiddleware, +# allow_origins=["*"], +# allow_credentials=True, +# allow_methods=["*"], +# allow_headers=["*"], +# ) -# Initialize S3 client for DigitalOcean Spaces -s3_client = boto3.client( - 's3', - region_name=DO_SPACES_REGION, - endpoint_url=DO_SPACES_ENDPOINT, - aws_access_key_id=DO_SPACES_KEY, - aws_secret_access_key=DO_SPACES_SECRET -) +# # --------------------------------------------------------------------- +# # Auth dependency +# # --------------------------------------------------------------------- +# async def verify_firebase_token(request: Request): +# """Middleware-like dependency to verify Firebase JWT from Authorization header.""" +# auth_header = request.headers.get("Authorization") +# if not auth_header or not auth_header.startswith("Bearer "): +# raise HTTPException(status_code=401, detail="Missing or invalid Authorization header") -# --------------------------------------------------------------------- -# FastAPI app setup -# --------------------------------------------------------------------- -app = FastAPI(title="Qwen Image Edit API with Firebase Auth") -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +# id_token = auth_header.split("Bearer ")[1] +# try: +# decoded_token = auth.verify_id_token(id_token) +# request.state.user = decoded_token +# return decoded_token +# except Exception as e: +# raise HTTPException(status_code=401, detail=f"Invalid or expired Firebase token: {e}") -# --------------------------------------------------------------------- -# Auth dependency -# --------------------------------------------------------------------- -async def verify_firebase_token(request: Request): - """Middleware-like dependency to verify Firebase JWT from Authorization header.""" - auth_header = request.headers.get("Authorization") - if not auth_header or not auth_header.startswith("Bearer "): - raise HTTPException(status_code=401, detail="Missing or invalid Authorization header") +# # --------------------------------------------------------------------- +# # Models +# # --------------------------------------------------------------------- +# class HealthResponse(BaseModel): +# status: str +# db: str +# model: str - id_token = auth_header.split("Bearer ")[1] - try: - decoded_token = auth.verify_id_token(id_token) - request.state.user = decoded_token - return decoded_token - except Exception as e: - raise HTTPException(status_code=401, detail=f"Invalid or expired Firebase token: {e}") +# # --------------------- UTILS --------------------- +# def resize_image_if_needed(img: Image.Image, max_size=(1024, 1024)) -> Image.Image: +# """ +# Resize image to fit within max_size while keeping aspect ratio. +# """ +# if img.width > max_size[0] or img.height > max_size[1]: +# img.thumbnail(max_size, Image.ANTIALIAS) +# return img +# # --------------------------------------------------------------------- +# # Lazy Gemini Initialization +# # --------------------------------------------------------------------- +# _genai_initialized = False +# def init_gemini(): +# global _genai_initialized, genai -# --------------------------------------------------------------------- -# Models -# --------------------------------------------------------------------- -class HealthResponse(BaseModel): - status: str - db: str - model: str +# if _genai_initialized: +# return -# --------------------- UTILS --------------------- -def resize_image_if_needed(img: Image.Image, max_size=(1024, 1024)) -> Image.Image: - """ - Resize image to fit within max_size while keeping aspect ratio. - """ - if img.width > max_size[0] or img.height > max_size[1]: - img.thumbnail(max_size, Image.ANTIALIAS) - return img -# --------------------------------------------------------------------- -# Lazy Gemini Initialization -# --------------------------------------------------------------------- -_genai_initialized = False -def init_gemini(): - global _genai_initialized, genai +# if not GEMINI_API_KEY: +# raise RuntimeError("❌ GEMINI_API_KEY not set") - if _genai_initialized: - return +# import google.generativeai as genai +# genai.configure(api_key=GEMINI_API_KEY) - if not GEMINI_API_KEY: - raise RuntimeError("❌ GEMINI_API_KEY not set") - - import google.generativeai as genai - genai.configure(api_key=GEMINI_API_KEY) - - _genai_initialized = True +# _genai_initialized = True -def expiry(hours: int): - """Return UTC datetime for TTL expiration""" - return datetime.utcnow() + timedelta(hours=hours) +# def expiry(hours: int): +# """Return UTC datetime for TTL expiration""" +# return datetime.utcnow() + timedelta(hours=hours) -def prepare_image(file_bytes: bytes) -> Image.Image: - """ - Open image and resize if larger than 1024x1024 - """ - img = Image.open(io.BytesIO(file_bytes)).convert("RGB") +# def prepare_image(file_bytes: bytes) -> Image.Image: +# """ +# Open image and resize if larger than 1024x1024 +# """ +# img = Image.open(io.BytesIO(file_bytes)).convert("RGB") - # ✅ MIN SIZE CHECK - if img.width < 200 or img.height < 200: - raise HTTPException( - status_code=400, - detail="Image size is below 200x200 pixels. Please upload a larger image." - ) - img = Image.open(io.BytesIO(file_bytes)).convert("RGB") - img = resize_image_if_needed(img, max_size=(1024, 1024)) - return img +# # ✅ MIN SIZE CHECK +# if img.width < 200 or img.height < 200: +# raise HTTPException( +# status_code=400, +# detail="Image size is below 200x200 pixels. Please upload a larger image." +# ) +# img = Image.open(io.BytesIO(file_bytes)).convert("RGB") +# img = resize_image_if_needed(img, max_size=(1024, 1024)) +# return img -def upload_to_digitalocean(image_bytes: bytes, folder: str, filename: str) -> str: - """ - Upload image to DigitalOcean Spaces and return the public URL. - folder: 'source' or 'results' - """ - key = f"valentine/{folder}/{filename}" - try: - s3_client.put_object( - Bucket=DO_SPACES_BUCKET, - Key=key, - Body=image_bytes, - ContentType="image/jpeg" if filename.endswith('.jpg') else "image/png", - ACL='public-read' - ) - url = f"{DO_SPACES_ENDPOINT}/{DO_SPACES_BUCKET}/{key}" - return url - except Exception as e: - raise RuntimeError(f"Failed to upload to DigitalOcean Spaces: {e}") +# def upload_to_digitalocean(image_bytes: bytes, folder: str, filename: str) -> str: +# """ +# Upload image to DigitalOcean Spaces and return the public URL. +# folder: 'source' or 'results' +# """ +# key = f"valentine/{folder}/{filename}" +# try: +# s3_client.put_object( +# Bucket=DO_SPACES_BUCKET, +# Key=key, +# Body=image_bytes, +# ContentType="image/jpeg" if filename.endswith('.jpg') else "image/png", +# ACL='public-read' +# ) +# url = f"{DO_SPACES_ENDPOINT}/{DO_SPACES_BUCKET}/{key}" +# return url +# except Exception as e: +# raise RuntimeError(f"Failed to upload to DigitalOcean Spaces: {e}") -def generate_image_id() -> str: - """ - Generate a random image ID (same format as MongoDB ObjectId for consistency) - """ - return str(uuid.uuid4().hex[:24]) +# def generate_image_id() -> str: +# """ +# Generate a random image ID (same format as MongoDB ObjectId for consistency) +# """ +# return str(uuid.uuid4().hex[:24]) -MAX_COMPRESSED_SIZE = 2 * 1024 * 1024 # 2 MB -def compress_pil_image_to_2mb( - pil_img: Image.Image, - max_dim: int = 1280 -) -> bytes: - """ - Resize + compress PIL image to <= 2MB. - Returns JPEG bytes. - """ - img = pil_img.convert("RGB") +# MAX_COMPRESSED_SIZE = 2 * 1024 * 1024 # 2 MB +# def compress_pil_image_to_2mb( +# pil_img: Image.Image, +# max_dim: int = 1280 +# ) -> bytes: +# """ +# Resize + compress PIL image to <= 2MB. +# Returns JPEG bytes. +# """ +# img = pil_img.convert("RGB") - # Resize (maintain aspect ratio) - img.thumbnail((max_dim, max_dim), Image.LANCZOS) +# # Resize (maintain aspect ratio) +# img.thumbnail((max_dim, max_dim), Image.LANCZOS) - quality = 85 - buffer = io.BytesIO() +# quality = 85 +# buffer = io.BytesIO() - while quality >= 40: - buffer.seek(0) - buffer.truncate() +# while quality >= 40: +# buffer.seek(0) +# buffer.truncate() - img.save( - buffer, - format="JPEG", - quality=quality, - optimize=True, - progressive=True - ) +# img.save( +# buffer, +# format="JPEG", +# quality=quality, +# optimize=True, +# progressive=True +# ) - if buffer.tell() <= MAX_COMPRESSED_SIZE: - break +# if buffer.tell() <= MAX_COMPRESSED_SIZE: +# break - quality -= 5 +# quality -= 5 - return buffer.getvalue() +# return buffer.getvalue() -def run_image_generation( - image1: Image.Image, - prompt: str, - image2: Optional[Image.Image] = None, - force_model: Optional[str] = None -) -> Image.Image: - """ - Unified image generation interface. - QWEN -> merges images if image2 exists - GEMINI -> passes images separately - """ - - effective_model = (force_model or MODEL).upper() - - # ---------------- QWEN ---------------- - if effective_model == "QWEN": +# def run_image_generation( +# image1: Image.Image, +# prompt: str, +# image2: Optional[Image.Image] = None, +# force_model: Optional[str] = None +# ) -> Image.Image: +# """ +# Unified image generation interface. +# QWEN -> merges images if image2 exists +# GEMINI -> passes images separately +# """ - # ✅ Merge images ONLY for QWEN - if image2: - total_width = image1.width + image2.width - max_height = max(image1.height, image2.height) - merged = Image.new("RGB", (total_width, max_height)) - merged.paste(image1, (0, 0)) - merged.paste(image2, (image1.width, 0)) - else: - merged = image1 +# effective_model = (force_model or MODEL).upper() - return hf_client.image_to_image( - image=merged, - prompt=prompt, - model="Qwen/Qwen-Image-Edit" - ) +# # ---------------- QWEN ---------------- +# if effective_model == "QWEN": - # ---------------- GEMINI ---------------- - elif effective_model == "GEMINI": - init_gemini() - model = genai.GenerativeModel(GEMINI_IMAGE_MODEL) +# # ✅ Merge images ONLY for QWEN +# if image2: +# total_width = image1.width + image2.width +# max_height = max(image1.height, image2.height) +# merged = Image.new("RGB", (total_width, max_height)) +# merged.paste(image1, (0, 0)) +# merged.paste(image2, (image1.width, 0)) +# else: +# merged = image1 - parts = [prompt] +# return hf_client.image_to_image( +# image=merged, +# prompt=prompt, +# model="Qwen/Qwen-Image-Edit" +# ) - def add_image(img: Image.Image): - buf = io.BytesIO() - img.save(buf, format="PNG") - buf.seek(0) - parts.append({ - "mime_type": "image/png", - "data": buf.getvalue() - }) +# # ---------------- GEMINI ---------------- +# elif effective_model == "GEMINI": +# init_gemini() +# model = genai.GenerativeModel(GEMINI_IMAGE_MODEL) - add_image(image1) +# parts = [prompt] - if image2: - add_image(image2) +# def add_image(img: Image.Image): +# buf = io.BytesIO() +# img.save(buf, format="PNG") +# buf.seek(0) +# parts.append({ +# "mime_type": "image/png", +# "data": buf.getvalue() +# }) - response = model.generate_content(parts) +# add_image(image1) - # -------- SAFE IMAGE EXTRACTION -------- - import base64 +# if image2: +# add_image(image2) - image_bytes = None - for candidate in response.candidates: - for part in candidate.content.parts: - if hasattr(part, "inline_data") and part.inline_data: - data = part.inline_data.data - image_bytes = ( - data if isinstance(data, (bytes, bytearray)) - else base64.b64decode(data) - ) - break - if image_bytes: - break +# response = model.generate_content(parts) - if not image_bytes: - raise RuntimeError("Gemini did not return an image") +# # -------- SAFE IMAGE EXTRACTION -------- +# import base64 - img = Image.open(io.BytesIO(image_bytes)) - img.verify() - return Image.open(io.BytesIO(image_bytes)).convert("RGB") +# image_bytes = None +# for candidate in response.candidates: +# for part in candidate.content.parts: +# if hasattr(part, "inline_data") and part.inline_data: +# data = part.inline_data.data +# image_bytes = ( +# data if isinstance(data, (bytes, bytearray)) +# else base64.b64decode(data) +# ) +# break +# if image_bytes: +# break - else: - raise RuntimeError(f"Unsupported IMAGE_MODEL: {effective_model}") +# if not image_bytes: +# raise RuntimeError("Gemini did not return an image") +# img = Image.open(io.BytesIO(image_bytes)) +# img.verify() +# return Image.open(io.BytesIO(image_bytes)).convert("RGB") +# else: +# raise RuntimeError(f"Unsupported IMAGE_MODEL: {effective_model}") -def get_admin_db(appname: Optional[str]): - """ - Returns (categories_collection, media_clicks_collection) - based on appname - """ - # ---- Collage Maker ---- - if appname == "collage-maker": - collage_uri = os.getenv("COLLAGE_MAKER_DB_URL") - if not collage_uri: - raise RuntimeError("COLLAGE_MAKER_DB_URL not set") - client = MongoClient(collage_uri) - db = client["adminPanel"] - return db.categories, db.media_clicks - # ---- AI ENHANCER ---- - if appname == "AI-Enhancer": - enhancer_uri = os.getenv("AI_ENHANCER_DB_URL") - if not enhancer_uri: - raise RuntimeError("AI_ENHANCER_DB_URL not set") +# def get_admin_db(appname: Optional[str]): +# """ +# Returns (categories_collection, media_clicks_collection) +# based on appname +# """ +# # ---- Collage Maker ---- +# if appname == "collage-maker": +# collage_uri = os.getenv("COLLAGE_MAKER_DB_URL") +# if not collage_uri: +# raise RuntimeError("COLLAGE_MAKER_DB_URL not set") - client = MongoClient(enhancer_uri) - db = client["test"] - return db.categories, db.media_clicks +# client = MongoClient(collage_uri) +# db = client["adminPanel"] +# return db.categories, db.media_clicks - # DEFAULT (existing behavior) - admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) - db = admin_client["adminPanel"] - return db.categories, db.media_clicks +# # ---- AI ENHANCER ---- +# if appname == "AI-Enhancer": +# enhancer_uri = os.getenv("AI_ENHANCER_DB_URL") +# if not enhancer_uri: +# raise RuntimeError("AI_ENHANCER_DB_URL not set") -# --------------------------------------------------------------------- -# Endpoints -# --------------------------------------------------------------------- -@app.get("/") -async def root(): - """Root endpoint""" - return { - "success": True, - "message": "Polaroid,Kiddo,Makeup,Hairstyle API", - "data": { - "version": "1.0.1", - "Product Name":"Beauty Camera - GlowCam AI Studio", - "Released By" : "LogicGo Infotech" - } - } +# client = MongoClient(enhancer_uri) +# db = client["test"] +# return db.categories, db.media_clicks + +# # DEFAULT (existing behavior) +# admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) +# db = admin_client["adminPanel"] +# return db.categories, db.media_clicks + +# # --------------------------------------------------------------------- +# # Endpoints +# # --------------------------------------------------------------------- +# @app.get("/") +# async def root(): +# """Root endpoint""" +# return { +# "success": True, +# "message": "Polaroid,Kiddo,Makeup,Hairstyle API", +# "data": { +# "version": "1.0.1", +# "Product Name":"Beauty Camera - GlowCam AI Studio", +# "Released By" : "LogicGo Infotech" +# } +# } -@app.get("/health", response_model=HealthResponse) -def health(): - """Public health check""" - mongo.admin.command("ping") - return HealthResponse(status="ok", db=db.name, model="Qwen/Qwen-Image-Edit") +# @app.get("/health", response_model=HealthResponse) +# def health(): +# """Public health check""" +# mongo.admin.command("ping") +# return HealthResponse(status="ok", db=db.name, model="Qwen/Qwen-Image-Edit") -@app.post("/generate") -async def generate( - prompt: str = Form(...), - image1: UploadFile = File(...), - image2: Optional[UploadFile] = File(None), - user_id: Optional[str] = Form(None), - category_id: Optional[str] = Form(None), - appname: Optional[str] = Form(None), - user=Depends(verify_firebase_token) -): - start_time = time.time() +# @app.post("/generate") +# async def generate( +# prompt: str = Form(...), +# image1: UploadFile = File(...), +# image2: Optional[UploadFile] = File(None), +# user_id: Optional[str] = Form(None), +# category_id: Optional[str] = Form(None), +# appname: Optional[str] = Form(None), +# user=Depends(verify_firebase_token) +# ): +# start_time = time.time() - # ------------------------- - # 1. VALIDATE & READ IMAGES - # ------------------------- - try: - img1_bytes = await image1.read() - pil_img1 = prepare_image(img1_bytes) - input1_id = generate_image_id() - input1_filename = f"{input1_id}_{image1.filename}" - input1_url = upload_to_digitalocean(img1_bytes, "source", input1_filename) - except Exception as e: - raise HTTPException(400, f"Failed to read first image: {e}") +# # ------------------------- +# # 1. VALIDATE & READ IMAGES +# # ------------------------- +# try: +# img1_bytes = await image1.read() +# pil_img1 = prepare_image(img1_bytes) +# input1_id = generate_image_id() +# input1_filename = f"{input1_id}_{image1.filename}" +# input1_url = upload_to_digitalocean(img1_bytes, "source", input1_filename) +# except Exception as e: +# raise HTTPException(400, f"Failed to read first image: {e}") - img2_bytes = None - input2_id = None - input2_url = None - pil_img2 = None +# img2_bytes = None +# input2_id = None +# input2_url = None +# pil_img2 = None - if image2: - try: - img2_bytes = await image2.read() - pil_img2 = prepare_image(img2_bytes) - input2_id = generate_image_id() - input2_filename = f"{input2_id}_{image2.filename}" - input2_url = upload_to_digitalocean(img2_bytes, "source", input2_filename) - except Exception as e: - raise HTTPException(400, f"Failed to read second image: {e}") +# if image2: +# try: +# img2_bytes = await image2.read() +# pil_img2 = prepare_image(img2_bytes) +# input2_id = generate_image_id() +# input2_filename = f"{input2_id}_{image2.filename}" +# input2_url = upload_to_digitalocean(img2_bytes, "source", input2_filename) +# except Exception as e: +# raise HTTPException(400, f"Failed to read second image: {e}") - # ------------------------- - # 3. CATEGORY CLICK LOGIC - # ------------------------- - if user_id and category_id: - try: - admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) - admin_db = admin_client["adminPanel"] +# # ------------------------- +# # 3. CATEGORY CLICK LOGIC +# # ------------------------- +# if user_id and category_id: +# try: +# admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) +# admin_db = admin_client["adminPanel"] - categories_col, media_clicks_col = get_admin_db(appname) - # categories_col = admin_db.categories - # media_clicks_col = admin_db.media_clicks - - # Validate user_oid & category_oid - user_oid = ObjectId(user_id) - category_oid = ObjectId(category_id) +# categories_col, media_clicks_col = get_admin_db(appname) +# # categories_col = admin_db.categories +# # media_clicks_col = admin_db.media_clicks - # Check category exists - category_doc = categories_col.find_one({"_id": category_oid}) - if not category_doc: - raise HTTPException(400, f"Invalid category_id: {category_id}") +# # Validate user_oid & category_oid +# user_oid = ObjectId(user_id) +# category_oid = ObjectId(category_id) - now = datetime.utcnow() +# # Check category exists +# category_doc = categories_col.find_one({"_id": category_oid}) +# if not category_doc: +# raise HTTPException(400, f"Invalid category_id: {category_id}") - # Normalize dates (UTC midnight) - today_date = datetime(now.year, now.month, now.day) - yesterday_date = today_date - timedelta(days=1) +# now = datetime.utcnow() - # -------------------------------------------------- - # AI EDIT USAGE TRACKING (GLOBAL PER USER) - # -------------------------------------------------- - media_clicks_col.update_one( - {"userId": user_oid}, - { - "$setOnInsert": { - "createdAt": now, - "ai_edit_daily_count": [] - }, - "$set": { - "ai_edit_last_date": now, - "updatedAt": now - }, - "$inc": { - "ai_edit_complete": 1 - } - }, - upsert=True - ) +# # Normalize dates (UTC midnight) +# today_date = datetime(now.year, now.month, now.day) +# yesterday_date = today_date - timedelta(days=1) - # -------------------------------------------------- - # DAILY COUNT LOGIC - # -------------------------------------------------- - now = datetime.utcnow() - today_date = datetime(now.year, now.month, now.day) - - doc = media_clicks_col.find_one( - {"userId": user_oid}, - {"ai_edit_daily_count": 1} - ) +# # -------------------------------------------------- +# # AI EDIT USAGE TRACKING (GLOBAL PER USER) +# # -------------------------------------------------- +# media_clicks_col.update_one( +# {"userId": user_oid}, +# { +# "$setOnInsert": { +# "createdAt": now, +# "ai_edit_daily_count": [] +# }, +# "$set": { +# "ai_edit_last_date": now, +# "updatedAt": now +# }, +# "$inc": { +# "ai_edit_complete": 1 +# } +# }, +# upsert=True +# ) + +# # -------------------------------------------------- +# # DAILY COUNT LOGIC +# # -------------------------------------------------- +# now = datetime.utcnow() +# today_date = datetime(now.year, now.month, now.day) - daily_entries = doc.get("ai_edit_daily_count", []) if doc else [] +# doc = media_clicks_col.find_one( +# {"userId": user_oid}, +# {"ai_edit_daily_count": 1} +# ) - # Build UNIQUE date -> count map - daily_map = {} - for entry in daily_entries: - d = entry["date"] - d = datetime(d.year, d.month, d.day) if isinstance(d, datetime) else d - daily_map[d] = entry["count"] # overwrite = no duplicates +# daily_entries = doc.get("ai_edit_daily_count", []) if doc else [] - # Find last known date - last_date = max(daily_map.keys()) if daily_map else today_date +# # Build UNIQUE date -> count map +# daily_map = {} +# for entry in daily_entries: +# d = entry["date"] +# d = datetime(d.year, d.month, d.day) if isinstance(d, datetime) else d +# daily_map[d] = entry["count"] # overwrite = no duplicates - # Fill ALL missing days with 0 - next_day = last_date + timedelta(days=1) - while next_day < today_date: - daily_map.setdefault(next_day, 0) - next_day += timedelta(days=1) +# # Find last known date +# last_date = max(daily_map.keys()) if daily_map else today_date - # Mark today as used (binary) - daily_map[today_date] = 1 +# # Fill ALL missing days with 0 +# next_day = last_date + timedelta(days=1) +# while next_day < today_date: +# daily_map.setdefault(next_day, 0) +# next_day += timedelta(days=1) - # Rebuild list (OLD → NEW) - final_daily_entries = [ - {"date": d, "count": daily_map[d]} - for d in sorted(daily_map.keys()) - ] +# # Mark today as used (binary) +# daily_map[today_date] = 1 - # Keep last 32 days only - final_daily_entries = final_daily_entries[-32:] +# # Rebuild list (OLD → NEW) +# final_daily_entries = [ +# {"date": d, "count": daily_map[d]} +# for d in sorted(daily_map.keys()) +# ] - # ATOMIC REPLACE (NO PUSH) - media_clicks_col.update_one( - {"userId": user_oid}, - { - "$set": { - "ai_edit_daily_count": final_daily_entries, - "ai_edit_last_date": now, - "updatedAt": now - } - } - ) +# # Keep last 32 days only +# final_daily_entries = final_daily_entries[-32:] + +# # ATOMIC REPLACE (NO PUSH) +# media_clicks_col.update_one( +# {"userId": user_oid}, +# { +# "$set": { +# "ai_edit_daily_count": final_daily_entries, +# "ai_edit_last_date": now, +# "updatedAt": now +# } +# } +# ) - # -------------------------------------------------- - # CATEGORY CLICK LOGIC - # -------------------------------------------------- - update_res = media_clicks_col.update_one( - {"userId": user_oid, "categories.categoryId": category_oid}, - { - "$set": { - "updatedAt": now, - "categories.$.lastClickedAt": now - }, - "$inc": { - "categories.$.click_count": 1 - } - } - ) +# # -------------------------------------------------- +# # CATEGORY CLICK LOGIC +# # -------------------------------------------------- +# update_res = media_clicks_col.update_one( +# {"userId": user_oid, "categories.categoryId": category_oid}, +# { +# "$set": { +# "updatedAt": now, +# "categories.$.lastClickedAt": now +# }, +# "$inc": { +# "categories.$.click_count": 1 +# } +# } +# ) - # If category does not exist → push new - if update_res.matched_count == 0: - media_clicks_col.update_one( - {"userId": user_oid}, - { - "$set": {"updatedAt": now}, - "$push": { - "categories": { - "categoryId": category_oid, - "click_count": 1, - "lastClickedAt": now - } - } - }, - upsert=True - ) +# # If category does not exist → push new +# if update_res.matched_count == 0: +# media_clicks_col.update_one( +# {"userId": user_oid}, +# { +# "$set": {"updatedAt": now}, +# "$push": { +# "categories": { +# "categoryId": category_oid, +# "click_count": 1, +# "lastClickedAt": now +# } +# } +# }, +# upsert=True +# ) - except Exception as e: - print("CATEGORY_LOG_ERROR:", e) - # ------------------------- - # 4. HF INFERENCE - # ------------------------- - try: - # -------------------------------------------------- - # MODEL OVERRIDE BASED ON CATEGORY - # -------------------------------------------------- - force_model = None +# except Exception as e: +# print("CATEGORY_LOG_ERROR:", e) +# # ------------------------- +# # 4. HF INFERENCE +# # ------------------------- +# try: +# # -------------------------------------------------- +# # MODEL OVERRIDE BASED ON CATEGORY +# # -------------------------------------------------- +# force_model = None - if category_id == GEMINI_FORCE_CATEGORY_ID: - force_model = "GEMINI" +# if category_id == GEMINI_FORCE_CATEGORY_ID: +# force_model = "GEMINI" - pil_output = run_image_generation( - image1=pil_img1, - image2=pil_img2, - prompt=prompt, - force_model="GEMINI" if category_id == GEMINI_FORCE_CATEGORY_ID else None - ) +# pil_output = run_image_generation( +# image1=pil_img1, +# image2=pil_img2, +# prompt=prompt, +# force_model="GEMINI" if category_id == GEMINI_FORCE_CATEGORY_ID else None +# ) - except Exception as e: - response_time_ms = round((time.time() - start_time) * 1000) - # logs_collection.insert_one({ - # "timestamp": datetime.utcnow(), - # "status": "failure", - # "input1_id": input1_id, - # "input2_id": input2_id if input2_id else None, - # "prompt": prompt, - # "user_email": user.get("email"), - # "error": str(e), - # "response_time_ms": response_time_ms, - # "appname": appname - # - if logs_collection is not None: - logs_collection.insert_one({ - "endpoint": "/generate", - "status": "fail", - "response_time_ms": float(response_time_ms), - "timestamp": datetime.utcnow(), - "appname": appname if appname else "None", - "error": str(e) - }) - raise HTTPException(500, f"Inference failed: {e}") +# except Exception as e: +# response_time_ms = round((time.time() - start_time) * 1000) +# # logs_collection.insert_one({ +# # "timestamp": datetime.utcnow(), +# # "status": "failure", +# # "input1_id": input1_id, +# # "input2_id": input2_id if input2_id else None, +# # "prompt": prompt, +# # "user_email": user.get("email"), +# # "error": str(e), +# # "response_time_ms": response_time_ms, +# # "appname": appname +# # +# if logs_collection is not None: +# logs_collection.insert_one({ +# "endpoint": "/generate", +# "status": "fail", +# "response_time_ms": float(response_time_ms), +# "timestamp": datetime.utcnow(), +# "appname": appname if appname else "None", +# "error": str(e) +# }) +# raise HTTPException(500, f"Inference failed: {e}") - # ------------------------- - # 5. SAVE OUTPUT IMAGE - # ------------------------- - output_object_id = ObjectId() - output_image_id = str(output_object_id) - out_buf = io.BytesIO() - pil_output.save(out_buf, format="PNG") - out_bytes = out_buf.getvalue() +# # ------------------------- +# # 5. SAVE OUTPUT IMAGE +# # ------------------------- +# output_object_id = ObjectId() +# output_image_id = str(output_object_id) +# out_buf = io.BytesIO() +# pil_output.save(out_buf, format="PNG") +# out_bytes = out_buf.getvalue() - out_filename = f"{output_image_id}_result.png" - out_url = upload_to_digitalocean(out_bytes, "results", out_filename) - try: - fs.put( - out_bytes, - _id=output_object_id, - filename=out_filename, - content_type="image/png", - metadata={ - "user_email": user.get("email"), - "prompt": prompt, - "created_at": datetime.utcnow() - } - ) - except Exception as e: - raise HTTPException(500, f"Failed to store output image in GridFS: {e}") +# out_filename = f"{output_image_id}_result.png" +# out_url = upload_to_digitalocean(out_bytes, "results", out_filename) +# try: +# fs.put( +# out_bytes, +# _id=output_object_id, +# filename=out_filename, +# content_type="image/png", +# metadata={ +# "user_email": user.get("email"), +# "prompt": prompt, +# "created_at": datetime.utcnow() +# } +# ) +# except Exception as e: +# raise HTTPException(500, f"Failed to store output image in GridFS: {e}") - # ------------------------- - # 5b. SAVE COMPRESSED IMAGE - # ------------------------- - compressed_bytes = compress_pil_image_to_2mb( - pil_output, - max_dim=1280 - ) +# # ------------------------- +# # 5b. SAVE COMPRESSED IMAGE +# # ------------------------- +# compressed_bytes = compress_pil_image_to_2mb( +# pil_output, +# max_dim=1280 +# ) - compressed_filename = f"{output_image_id}_compressed.jpg" - compressed_url = upload_to_digitalocean(compressed_bytes, "results", compressed_filename) +# compressed_filename = f"{output_image_id}_compressed.jpg" +# compressed_url = upload_to_digitalocean(compressed_bytes, "results", compressed_filename) - response_time_ms = round((time.time() - start_time) * 1000) +# response_time_ms = round((time.time() - start_time) * 1000) - # ------------------------- - # 6. LOG SUCCESS - # ------------------------- - # logs_collection.insert_one({ - # "timestamp": datetime.utcnow(), - # "status": "success", - # "input1_id": input1_id, - # "input2_id": input2_id if input2_id else None, - # "output_id": output_image_id, - # "prompt": prompt, - # "user_email": user.get("email"), - # "response_time_ms": response_time_ms, - # "appname": appname - # }) - if logs_collection is not None: - logs_collection.insert_one({ - "endpoint": "/generate", - "status": "success", - "response_time_ms": float(response_time_ms), - "timestamp": datetime.utcnow(), - "appname": appname if appname else "None", - "error": None - }) - return JSONResponse({ - "output_image_id": output_image_id, - "user": user.get("email"), - "response_time_ms": response_time_ms, - "Compressed_Image_URL": compressed_url - }) - - -# Image endpoint removed - images are now stored directly in DigitalOcean Spaces -# and are publicly accessible via the Compressed_Image_URL returned in the /generate response -@app.get("/image/{image_id}") -async def download_image( - image_id: str -): - try: - file_obj_id = ObjectId(image_id) - except Exception: - raise HTTPException(status_code=400, detail="Invalid image id format") - - def _read_from_gridfs(): - grid_out = fs.get(file_obj_id) - return grid_out.read(), grid_out.filename, getattr(grid_out, "content_type", "image/png") - - try: - file_bytes, filename, content_type = await run_in_threadpool(_read_from_gridfs) - except NoFile: - raise HTTPException(status_code=404, detail="Image not found") - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to download image: {e}") - - return StreamingResponse( - io.BytesIO(file_bytes), - media_type=content_type or "image/png", - headers={"Content-Disposition": f'attachment; filename="{filename or f"{image_id}.png"}"'} - ) +# # ------------------------- +# # 6. LOG SUCCESS +# # ------------------------- +# # logs_collection.insert_one({ +# # "timestamp": datetime.utcnow(), +# # "status": "success", +# # "input1_id": input1_id, +# # "input2_id": input2_id if input2_id else None, +# # "output_id": output_image_id, +# # "prompt": prompt, +# # "user_email": user.get("email"), +# # "response_time_ms": response_time_ms, +# # "appname": appname +# # }) +# if logs_collection is not None: +# logs_collection.insert_one({ +# "endpoint": "/generate", +# "status": "success", +# "response_time_ms": float(response_time_ms), +# "timestamp": datetime.utcnow(), +# "appname": appname if appname else "None", +# "error": None +# }) +# return JSONResponse({ +# "output_image_id": output_image_id, +# "user": user.get("email"), +# "response_time_ms": response_time_ms, +# "Compressed_Image_URL": compressed_url +# }) -# --------------------------------------------------------------------- -# Run locally -# --------------------------------------------------------------------- -if __name__ == "__main__": - import uvicorn - uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True) +# # Image endpoint removed - images are now stored directly in DigitalOcean Spaces +# # and are publicly accessible via the Compressed_Image_URL returned in the /generate response +# @app.get("/image/{image_id}") +# async def download_image( +# image_id: str +# ): +# try: +# file_obj_id = ObjectId(image_id) +# except Exception: +# raise HTTPException(status_code=400, detail="Invalid image id format") +# def _read_from_gridfs(): +# grid_out = fs.get(file_obj_id) +# return grid_out.read(), grid_out.filename, getattr(grid_out, "content_type", "image/png") +# try: +# file_bytes, filename, content_type = await run_in_threadpool(_read_from_gridfs) +# except NoFile: +# raise HTTPException(status_code=404, detail="Image not found") +# except Exception as e: +# raise HTTPException(status_code=500, detail=f"Failed to download image: {e}") +# return StreamingResponse( +# io.BytesIO(file_bytes), +# media_type=content_type or "image/png", +# headers={"Content-Disposition": f'attachment; filename="{filename or f"{image_id}.png"}"'} +# ) -# import os -# import io -# import json -# import traceback -# from datetime import datetime,timedelta -# from typing import Optional -# import time -# import uuid -# import boto3 -# from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends -# from fastapi.responses import StreamingResponse, JSONResponse -# from fastapi.middleware.cors import CORSMiddleware -# from pydantic import BaseModel -# from pymongo import MongoClient -# import gridfs -# from bson.objectid import ObjectId -# from PIL import Image -# from fastapi.concurrency import run_in_threadpool -# import shutil -# import firebase_admin -# from firebase_admin import credentials, auth -# from PIL import Image -# from huggingface_hub import InferenceClient # # --------------------------------------------------------------------- -# # Load Firebase Config from env (stringified JSON) +# # Run locally # # --------------------------------------------------------------------- -# firebase_config_json = os.getenv("firebase_config") -# if not firebase_config_json: -# raise RuntimeError("❌ Missing Firebase config in environment variable 'firebase_config'") +# if __name__ == "__main__": +# import uvicorn +# uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True) -# try: -# firebase_creds_dict = json.loads(firebase_config_json) -# cred = credentials.Certificate(firebase_creds_dict) -# firebase_admin.initialize_app(cred) -# except Exception as e: -# raise RuntimeError(f"Failed to initialize Firebase Admin SDK: {e}") -# # --------------------------------------------------------------------- -# # Hugging Face setup -# # --------------------------------------------------------------------- -# HF_TOKEN = os.getenv("HF_TOKEN") -# if not HF_TOKEN: -# raise RuntimeError("HF_TOKEN not set in environment variables") -# hf_client = InferenceClient(token=HF_TOKEN) -# # --------------------------------------------------------------------- -# # MODEL SELECTION -# # --------------------------------------------------------------------- -# genai = None # ✅ IMPORTANT: module-level declaration -# MODEL = os.getenv("IMAGE_MODEL", "GEMINI").upper() -# GEMINI_FORCE_CATEGORY_ID = "69368e741224bcb6bdb98076" -# # --------------------------------------------------------------------- -# # Gemini setup (ONLY used if MODEL == "GEMINI") -# # --------------------------------------------------------------------- -# GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") -# GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-2.5-flash-image") -# # --------------------------------------------------------------------- -# # MongoDB setup -# # --------------------------------------------------------------------- -# MONGODB_URI=os.getenv("MONGODB_URI") -# DB_NAME = "polaroid_db" -# mongo = MongoClient(MONGODB_URI) -# db = mongo[DB_NAME] -# fs = gridfs.GridFS(db) -# logs_collection = db["logs"] +import os +import io +import json +import traceback +from datetime import datetime,timedelta +from typing import Optional +import time +import uuid +import boto3 +from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends +from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from pymongo import MongoClient +import gridfs +from bson.objectid import ObjectId +from PIL import Image +from fastapi.concurrency import run_in_threadpool +import shutil +import firebase_admin +from firebase_admin import credentials, auth +from PIL import Image +from huggingface_hub import InferenceClient +# --------------------------------------------------------------------- +# Load Firebase Config from env (stringified JSON) +# --------------------------------------------------------------------- +firebase_config_json = os.getenv("firebase_config") +if not firebase_config_json: + raise RuntimeError("❌ Missing Firebase config in environment variable 'firebase_config'") -# # --------------------------------------------------------------------- -# # DigitalOcean Spaces setup -# # --------------------------------------------------------------------- -# DO_SPACES_KEY = os.getenv("DO_SPACES_KEY") -# DO_SPACES_SECRET = os.getenv("DO_SPACES_SECRET") -# DO_SPACES_REGION = os.getenv("DO_SPACES_REGION", "blr1") -# DO_SPACES_BUCKET = os.getenv("DO_SPACES_BUCKET") -# DO_SPACES_ENDPOINT = os.getenv("DO_SPACES_ENDPOINT", f"https://{DO_SPACES_REGION}.digitaloceanspaces.com") +try: + firebase_creds_dict = json.loads(firebase_config_json) + cred = credentials.Certificate(firebase_creds_dict) + firebase_admin.initialize_app(cred) +except Exception as e: + raise RuntimeError(f"Failed to initialize Firebase Admin SDK: {e}") -# if not DO_SPACES_KEY or not DO_SPACES_SECRET: -# raise RuntimeError("Missing DigitalOcean Spaces credentials in environment variables") +# --------------------------------------------------------------------- +# Hugging Face setup +# --------------------------------------------------------------------- +HF_TOKEN = os.getenv("HF_TOKEN") +if not HF_TOKEN: + raise RuntimeError("HF_TOKEN not set in environment variables") -# # Initialize S3 client for DigitalOcean Spaces -# s3_client = boto3.client( -# 's3', -# region_name=DO_SPACES_REGION, -# endpoint_url=DO_SPACES_ENDPOINT, -# aws_access_key_id=DO_SPACES_KEY, -# aws_secret_access_key=DO_SPACES_SECRET -# ) +hf_client = InferenceClient(token=HF_TOKEN) +# --------------------------------------------------------------------- +# MODEL SELECTION +# --------------------------------------------------------------------- +genai = None # ✅ IMPORTANT: module-level declaration +MODEL = os.getenv("IMAGE_MODEL", "GEMINI").upper() +GEMINI_FORCE_CATEGORY_ID = "69368e741224bcb6bdb98076" +# --------------------------------------------------------------------- +# Gemini setup (ONLY used if MODEL == "GEMINI") +# --------------------------------------------------------------------- +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") +GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-2.5-flash-image") -# # --------------------------------------------------------------------- -# # FastAPI app setup -# # --------------------------------------------------------------------- -# app = FastAPI(title="Qwen Image Edit API with Firebase Auth") -# app.add_middleware( -# CORSMiddleware, -# allow_origins=["*"], -# allow_credentials=True, -# allow_methods=["*"], -# allow_headers=["*"], -# ) +# --------------------------------------------------------------------- +# MongoDB setup +# --------------------------------------------------------------------- +MONGODB_URI=os.getenv("MONGODB_URI") +DB_NAME = "polaroid_db" -# # --------------------------------------------------------------------- -# # Auth dependency -# # --------------------------------------------------------------------- -# async def verify_firebase_token(request: Request): -# """Middleware-like dependency to verify Firebase JWT from Authorization header.""" -# auth_header = request.headers.get("Authorization") -# if not auth_header or not auth_header.startswith("Bearer "): -# raise HTTPException(status_code=401, detail="Missing or invalid Authorization header") +mongo = MongoClient(MONGODB_URI) +db = mongo[DB_NAME] +fs = gridfs.GridFS(db) +logs_collection = db["logs"] -# id_token = auth_header.split("Bearer ")[1] -# try: -# decoded_token = auth.verify_id_token(id_token) -# request.state.user = decoded_token -# return decoded_token -# except Exception as e: -# raise HTTPException(status_code=401, detail=f"Invalid or expired Firebase token: {e}") +# --------------------------------------------------------------------- +# DigitalOcean Spaces setup +# --------------------------------------------------------------------- +DO_SPACES_KEY = os.getenv("DO_SPACES_KEY") +DO_SPACES_SECRET = os.getenv("DO_SPACES_SECRET") +DO_SPACES_REGION = os.getenv("DO_SPACES_REGION", "blr1") +DO_SPACES_BUCKET = os.getenv("DO_SPACES_BUCKET") +DO_SPACES_ENDPOINT = os.getenv("DO_SPACES_ENDPOINT", f"https://{DO_SPACES_REGION}.digitaloceanspaces.com") -# # --------------------------------------------------------------------- -# # Models -# # --------------------------------------------------------------------- -# class HealthResponse(BaseModel): -# status: str -# db: str -# model: str +if not DO_SPACES_KEY or not DO_SPACES_SECRET: + raise RuntimeError("Missing DigitalOcean Spaces credentials in environment variables") -# # --------------------- UTILS --------------------- -# def resize_image_if_needed(img: Image.Image, max_size=(1024, 1024)) -> Image.Image: -# """ -# Resize image to fit within max_size while keeping aspect ratio. -# """ -# if img.width > max_size[0] or img.height > max_size[1]: -# img.thumbnail(max_size, Image.ANTIALIAS) -# return img -# # --------------------------------------------------------------------- -# # Lazy Gemini Initialization -# # --------------------------------------------------------------------- -# _genai_initialized = False -# def init_gemini(): -# global _genai_initialized, genai +# Initialize S3 client for DigitalOcean Spaces +s3_client = boto3.client( + 's3', + region_name=DO_SPACES_REGION, + endpoint_url=DO_SPACES_ENDPOINT, + aws_access_key_id=DO_SPACES_KEY, + aws_secret_access_key=DO_SPACES_SECRET +) -# if _genai_initialized: -# return +# --------------------------------------------------------------------- +# FastAPI app setup +# --------------------------------------------------------------------- +app = FastAPI(title="Qwen Image Edit API with Firebase Auth") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) -# if not GEMINI_API_KEY: -# raise RuntimeError("❌ GEMINI_API_KEY not set") +# --------------------------------------------------------------------- +# Auth dependency +# --------------------------------------------------------------------- +async def verify_firebase_token(request: Request): + """Middleware-like dependency to verify Firebase JWT from Authorization header.""" + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Missing or invalid Authorization header") -# import google.generativeai as genai -# genai.configure(api_key=GEMINI_API_KEY) + id_token = auth_header.split("Bearer ")[1] + try: + decoded_token = auth.verify_id_token(id_token) + request.state.user = decoded_token + return decoded_token + except Exception as e: + raise HTTPException(status_code=401, detail=f"Invalid or expired Firebase token: {e}") -# _genai_initialized = True +# --------------------------------------------------------------------- +# Models +# --------------------------------------------------------------------- +class HealthResponse(BaseModel): + status: str + db: str + model: str + +# --------------------- UTILS --------------------- +def resize_image_if_needed(img: Image.Image, max_size=(1024, 1024)) -> Image.Image: + """ + Resize image to fit within max_size while keeping aspect ratio. + """ + if img.width > max_size[0] or img.height > max_size[1]: + img.thumbnail(max_size, Image.ANTIALIAS) + return img +# --------------------------------------------------------------------- +# Lazy Gemini Initialization +# --------------------------------------------------------------------- +_genai_initialized = False +def init_gemini(): + global _genai_initialized, genai + + if _genai_initialized: + return + + if not GEMINI_API_KEY: + raise RuntimeError("❌ GEMINI_API_KEY not set") + + import google.generativeai as genai + genai.configure(api_key=GEMINI_API_KEY) + + _genai_initialized = True -# def expiry(hours: int): -# """Return UTC datetime for TTL expiration""" -# return datetime.utcnow() + timedelta(hours=hours) +def expiry(hours: int): + """Return UTC datetime for TTL expiration""" + return datetime.utcnow() + timedelta(hours=hours) -# def prepare_image(file_bytes: bytes) -> Image.Image: -# """ -# Open image and resize if larger than 1024x1024 -# """ -# img = Image.open(io.BytesIO(file_bytes)).convert("RGB") +def prepare_image(file_bytes: bytes) -> Image.Image: + """ + Open image and resize if larger than 1024x1024 + """ + img = Image.open(io.BytesIO(file_bytes)).convert("RGB") -# # ✅ MIN SIZE CHECK -# if img.width < 200 or img.height < 200: -# raise HTTPException( -# status_code=400, -# detail="Image size is below 200x200 pixels. Please upload a larger image." -# ) -# img = Image.open(io.BytesIO(file_bytes)).convert("RGB") -# img = resize_image_if_needed(img, max_size=(1024, 1024)) -# return img + # ✅ MIN SIZE CHECK + if img.width < 200 or img.height < 200: + raise HTTPException( + status_code=400, + detail="Image size is below 200x200 pixels. Please upload a larger image." + ) + img = Image.open(io.BytesIO(file_bytes)).convert("RGB") + img = resize_image_if_needed(img, max_size=(1024, 1024)) + return img -# def upload_to_digitalocean(image_bytes: bytes, folder: str, filename: str) -> str: -# """ -# Upload image to DigitalOcean Spaces and return the public URL. -# folder: 'source' or 'results' -# """ -# key = f"valentine/{folder}/{filename}" -# try: -# s3_client.put_object( -# Bucket=DO_SPACES_BUCKET, -# Key=key, -# Body=image_bytes, -# ContentType="image/jpeg" if filename.endswith('.jpg') else "image/png", -# ACL='public-read' -# ) -# url = f"{DO_SPACES_ENDPOINT}/{DO_SPACES_BUCKET}/{key}" -# return url -# except Exception as e: -# raise RuntimeError(f"Failed to upload to DigitalOcean Spaces: {e}") +def upload_to_digitalocean(image_bytes: bytes, folder: str, filename: str) -> str: + """ + Upload image to DigitalOcean Spaces and return the public URL. + folder: 'source' or 'results' + """ + key = f"valentine/{folder}/{filename}" + try: + s3_client.put_object( + Bucket=DO_SPACES_BUCKET, + Key=key, + Body=image_bytes, + ContentType="image/jpeg" if filename.endswith('.jpg') else "image/png", + ACL='public-read' + ) + url = f"{DO_SPACES_ENDPOINT}/{DO_SPACES_BUCKET}/{key}" + return url + except Exception as e: + raise RuntimeError(f"Failed to upload to DigitalOcean Spaces: {e}") -# def generate_image_id() -> str: -# """ -# Generate a random image ID (same format as MongoDB ObjectId for consistency) -# """ -# return str(uuid.uuid4().hex[:24]) +def generate_image_id() -> str: + """ + Generate a random image ID (same format as MongoDB ObjectId for consistency) + """ + return str(uuid.uuid4().hex[:24]) -# MAX_COMPRESSED_SIZE = 2 * 1024 * 1024 # 2 MB -# def compress_pil_image_to_2mb( -# pil_img: Image.Image, -# max_dim: int = 1280 -# ) -> bytes: -# """ -# Resize + compress PIL image to <= 2MB. -# Returns JPEG bytes. -# """ -# img = pil_img.convert("RGB") +MAX_COMPRESSED_SIZE = 2 * 1024 * 1024 # 2 MB +def compress_pil_image_to_2mb( + pil_img: Image.Image, + max_dim: int = 1280 +) -> bytes: + """ + Resize + compress PIL image to <= 2MB. + Returns JPEG bytes. + """ + img = pil_img.convert("RGB") -# # Resize (maintain aspect ratio) -# img.thumbnail((max_dim, max_dim), Image.LANCZOS) + # Resize (maintain aspect ratio) + img.thumbnail((max_dim, max_dim), Image.LANCZOS) -# quality = 85 -# buffer = io.BytesIO() + quality = 85 + buffer = io.BytesIO() -# while quality >= 40: -# buffer.seek(0) -# buffer.truncate() + while quality >= 40: + buffer.seek(0) + buffer.truncate() -# img.save( -# buffer, -# format="JPEG", -# quality=quality, -# optimize=True, -# progressive=True -# ) + img.save( + buffer, + format="JPEG", + quality=quality, + optimize=True, + progressive=True + ) -# if buffer.tell() <= MAX_COMPRESSED_SIZE: -# break + if buffer.tell() <= MAX_COMPRESSED_SIZE: + break -# quality -= 5 + quality -= 5 -# return buffer.getvalue() + return buffer.getvalue() -# def run_image_generation( -# image1: Image.Image, -# prompt: str, -# image2: Optional[Image.Image] = None, -# force_model: Optional[str] = None -# ) -> Image.Image: -# """ -# Unified image generation interface. -# QWEN -> merges images if image2 exists -# GEMINI -> passes images separately -# """ +def run_image_generation( + image1: Image.Image, + prompt: str, + image2: Optional[Image.Image] = None, + force_model: Optional[str] = None +) -> Image.Image: + """ + Unified image generation interface. + QWEN -> merges images if image2 exists + GEMINI -> passes images separately + """ -# effective_model = (force_model or MODEL).upper() + effective_model = (force_model or MODEL).upper() -# # ---------------- QWEN ---------------- -# if effective_model == "QWEN": + # ---------------- QWEN ---------------- + if effective_model == "QWEN": -# # ✅ Merge images ONLY for QWEN -# if image2: -# total_width = image1.width + image2.width -# max_height = max(image1.height, image2.height) -# merged = Image.new("RGB", (total_width, max_height)) -# merged.paste(image1, (0, 0)) -# merged.paste(image2, (image1.width, 0)) -# else: -# merged = image1 + # ✅ Merge images ONLY for QWEN + if image2: + total_width = image1.width + image2.width + max_height = max(image1.height, image2.height) + merged = Image.new("RGB", (total_width, max_height)) + merged.paste(image1, (0, 0)) + merged.paste(image2, (image1.width, 0)) + else: + merged = image1 -# return hf_client.image_to_image( -# image=merged, -# prompt=prompt, -# model="Qwen/Qwen-Image-Edit" -# ) + return hf_client.image_to_image( + image=merged, + prompt=prompt, + model="Qwen/Qwen-Image-Edit" + ) -# # ---------------- GEMINI ---------------- -# elif effective_model == "GEMINI": -# init_gemini() -# model = genai.GenerativeModel(GEMINI_IMAGE_MODEL) + # ---------------- GEMINI ---------------- + elif effective_model == "GEMINI": + init_gemini() + model = genai.GenerativeModel(GEMINI_IMAGE_MODEL) -# parts = [prompt] + parts = [prompt] -# def add_image(img: Image.Image): -# buf = io.BytesIO() -# img.save(buf, format="PNG") -# buf.seek(0) -# parts.append({ -# "mime_type": "image/png", -# "data": buf.getvalue() -# }) + def add_image(img: Image.Image): + buf = io.BytesIO() + img.save(buf, format="PNG") + buf.seek(0) + parts.append({ + "mime_type": "image/png", + "data": buf.getvalue() + }) -# add_image(image1) + add_image(image1) -# if image2: -# add_image(image2) + if image2: + add_image(image2) -# response = model.generate_content(parts) + response = model.generate_content(parts) -# # -------- SAFE IMAGE EXTRACTION -------- -# import base64 + # -------- SAFE IMAGE EXTRACTION -------- + import base64 -# image_bytes = None -# for candidate in response.candidates: -# for part in candidate.content.parts: -# if hasattr(part, "inline_data") and part.inline_data: -# data = part.inline_data.data -# image_bytes = ( -# data if isinstance(data, (bytes, bytearray)) -# else base64.b64decode(data) -# ) -# break -# if image_bytes: -# break + image_bytes = None + for candidate in response.candidates: + for part in candidate.content.parts: + if hasattr(part, "inline_data") and part.inline_data: + data = part.inline_data.data + image_bytes = ( + data if isinstance(data, (bytes, bytearray)) + else base64.b64decode(data) + ) + break + if image_bytes: + break -# if not image_bytes: -# raise RuntimeError("Gemini did not return an image") + if not image_bytes: + raise RuntimeError("Gemini did not return an image") -# img = Image.open(io.BytesIO(image_bytes)) -# img.verify() -# return Image.open(io.BytesIO(image_bytes)).convert("RGB") + img = Image.open(io.BytesIO(image_bytes)) + img.verify() + return Image.open(io.BytesIO(image_bytes)).convert("RGB") -# else: -# raise RuntimeError(f"Unsupported IMAGE_MODEL: {effective_model}") + else: + raise RuntimeError(f"Unsupported IMAGE_MODEL: {effective_model}") -# def get_admin_db(appname: Optional[str]): -# """ -# Returns (categories_collection, media_clicks_collection) -# based on appname -# """ -# # ---- Collage Maker ---- -# if appname == "collage-maker": -# collage_uri = os.getenv("COLLAGE_MAKER_DB_URL") -# if not collage_uri: -# raise RuntimeError("COLLAGE_MAKER_DB_URL not set") +def get_admin_db(appname: Optional[str]): + """ + Returns (categories_collection, media_clicks_collection) + based on appname + """ + # ---- Collage Maker ---- + if appname == "collage-maker": + collage_uri = os.getenv("COLLAGE_MAKER_DB_URL") + if not collage_uri: + raise RuntimeError("COLLAGE_MAKER_DB_URL not set") -# client = MongoClient(collage_uri) -# db = client["adminPanel"] -# return db.categories, db.media_clicks + client = MongoClient(collage_uri) + db = client["adminPanel"] + return db.categories, db.media_clicks -# # ---- AI ENHANCER ---- -# if appname == "AI-Enhancer": -# enhancer_uri = os.getenv("AI_ENHANCER_DB_URL") -# if not enhancer_uri: -# raise RuntimeError("AI_ENHANCER_DB_URL not set") + # ---- AI ENHANCER ---- + if appname == "AI-Enhancer": + enhancer_uri = os.getenv("AI_ENHANCER_DB_URL") + if not enhancer_uri: + raise RuntimeError("AI_ENHANCER_DB_URL not set") -# client = MongoClient(enhancer_uri) -# db = client["test"] -# return db.categories, db.media_clicks + client = MongoClient(enhancer_uri) + db = client["test"] + return db.categories, db.media_clicks -# # DEFAULT (existing behavior) -# admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) -# db = admin_client["adminPanel"] -# return db.categories, db.media_clicks + # DEFAULT (existing behavior) + admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) + db = admin_client["adminPanel"] + return db.categories, db.media_clicks -# # --------------------------------------------------------------------- -# # Endpoints -# # --------------------------------------------------------------------- -# @app.get("/") -# async def root(): -# """Root endpoint""" -# return { -# "success": True, -# "message": "Polaroid,Kiddo,Makeup,Hairstyle API", -# "data": { -# "version": "1.0.1", -# "Product Name":"Beauty Camera - GlowCam AI Studio", -# "Released By" : "LogicGo Infotech" -# } -# } +# --------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------- +@app.get("/") +async def root(): + """Root endpoint""" + return { + "success": True, + "message": "Polaroid,Kiddo,Makeup,Hairstyle API", + "data": { + "version": "1.0.1", + "Product Name":"Beauty Camera - GlowCam AI Studio", + "Released By" : "LogicGo Infotech" + } + } -# @app.get("/health", response_model=HealthResponse) -# def health(): -# """Public health check""" -# mongo.admin.command("ping") -# return HealthResponse(status="ok", db=db.name, model="Qwen/Qwen-Image-Edit") +@app.get("/health", response_model=HealthResponse) +def health(): + """Public health check""" + mongo.admin.command("ping") + return HealthResponse(status="ok", db=db.name, model="Qwen/Qwen-Image-Edit") -# @app.post("/generate") -# async def generate( -# prompt: str = Form(...), -# image1: UploadFile = File(...), -# image2: Optional[UploadFile] = File(None), -# user_id: Optional[str] = Form(None), -# category_id: Optional[str] = Form(None), -# appname: Optional[str] = Form(None), -# user=Depends(verify_firebase_token) -# ): -# start_time = time.time() +@app.post("/generate") +async def generate( + prompt: str = Form(...), + image1: UploadFile = File(...), + image2: Optional[UploadFile] = File(None), + user_id: Optional[str] = Form(None), + category_id: Optional[str] = Form(None), + appname: Optional[str] = Form(None), + user=Depends(verify_firebase_token) +): + start_time = time.time() -# # ------------------------- -# # 1. VALIDATE & READ IMAGES -# # ------------------------- -# try: -# img1_bytes = await image1.read() -# pil_img1 = prepare_image(img1_bytes) -# input1_id = generate_image_id() -# input1_filename = f"{input1_id}_{image1.filename}" -# input1_url = upload_to_digitalocean(img1_bytes, "source", input1_filename) -# except Exception as e: -# raise HTTPException(400, f"Failed to read first image: {e}") + # ------------------------- + # 1. VALIDATE & READ IMAGES + # ------------------------- + try: + img1_bytes = await image1.read() + pil_img1 = prepare_image(img1_bytes) + input1_id = generate_image_id() + input1_filename = f"{input1_id}_{image1.filename}" + input1_url = upload_to_digitalocean(img1_bytes, "source", input1_filename) + except Exception as e: + raise HTTPException(400, f"Failed to read first image: {e}") -# img2_bytes = None -# input2_id = None -# input2_url = None -# pil_img2 = None + img2_bytes = None + input2_id = None + input2_url = None + pil_img2 = None -# if image2: -# try: -# img2_bytes = await image2.read() -# pil_img2 = prepare_image(img2_bytes) -# input2_id = generate_image_id() -# input2_filename = f"{input2_id}_{image2.filename}" -# input2_url = upload_to_digitalocean(img2_bytes, "source", input2_filename) -# except Exception as e: -# raise HTTPException(400, f"Failed to read second image: {e}") + if image2: + try: + img2_bytes = await image2.read() + pil_img2 = prepare_image(img2_bytes) + input2_id = generate_image_id() + input2_filename = f"{input2_id}_{image2.filename}" + input2_url = upload_to_digitalocean(img2_bytes, "source", input2_filename) + except Exception as e: + raise HTTPException(400, f"Failed to read second image: {e}") -# # ------------------------- -# # 3. CATEGORY CLICK LOGIC -# # ------------------------- -# if user_id and category_id: -# try: -# admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) -# admin_db = admin_client["adminPanel"] + # ------------------------- + # 3. CATEGORY CLICK LOGIC + # ------------------------- + if user_id and category_id: + try: + admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) + admin_db = admin_client["adminPanel"] -# categories_col, media_clicks_col = get_admin_db(appname) -# # categories_col = admin_db.categories -# # media_clicks_col = admin_db.media_clicks + categories_col, media_clicks_col = get_admin_db(appname) + # categories_col = admin_db.categories + # media_clicks_col = admin_db.media_clicks -# # Validate user_oid & category_oid -# user_oid = ObjectId(user_id) -# category_oid = ObjectId(category_id) + # Validate user_oid & category_oid + user_oid = ObjectId(user_id) + category_oid = ObjectId(category_id) -# # Check category exists -# category_doc = categories_col.find_one({"_id": category_oid}) -# if not category_doc: -# raise HTTPException(400, f"Invalid category_id: {category_id}") + # Check category exists + category_doc = categories_col.find_one({"_id": category_oid}) + if not category_doc: + raise HTTPException(400, f"Invalid category_id: {category_id}") -# now = datetime.utcnow() + now = datetime.utcnow() -# # Normalize dates (UTC midnight) -# today_date = datetime(now.year, now.month, now.day) -# yesterday_date = today_date - timedelta(days=1) + # Normalize dates (UTC midnight) + today_date = datetime(now.year, now.month, now.day) + yesterday_date = today_date - timedelta(days=1) -# # -------------------------------------------------- -# # AI EDIT USAGE TRACKING (GLOBAL PER USER) -# # -------------------------------------------------- -# media_clicks_col.update_one( -# {"userId": user_oid}, -# { -# "$setOnInsert": { -# "createdAt": now, -# "ai_edit_daily_count": [] -# }, -# "$set": { -# "ai_edit_last_date": now, -# "updatedAt": now -# }, -# "$inc": { -# "ai_edit_complete": 1 -# } -# }, -# upsert=True -# ) + # -------------------------------------------------- + # AI EDIT USAGE TRACKING (GLOBAL PER USER) + # -------------------------------------------------- + media_clicks_col.update_one( + {"userId": user_oid}, + { + "$setOnInsert": { + "createdAt": now, + "ai_edit_daily_count": [] + }, + "$set": { + "ai_edit_last_date": now, + "updatedAt": now + }, + "$inc": { + "ai_edit_complete": 1 + } + }, + upsert=True + ) -# # -------------------------------------------------- -# # DAILY COUNT LOGIC -# # -------------------------------------------------- -# now = datetime.utcnow() -# today_date = datetime(now.year, now.month, now.day) + # -------------------------------------------------- + # DAILY COUNT LOGIC + # -------------------------------------------------- + now = datetime.utcnow() + today_date = datetime(now.year, now.month, now.day) -# doc = media_clicks_col.find_one( -# {"userId": user_oid}, -# {"ai_edit_daily_count": 1} -# ) + doc = media_clicks_col.find_one( + {"userId": user_oid}, + {"ai_edit_daily_count": 1} + ) -# daily_entries = doc.get("ai_edit_daily_count", []) if doc else [] + daily_entries = doc.get("ai_edit_daily_count", []) if doc else [] -# # Build UNIQUE date -> count map -# daily_map = {} -# for entry in daily_entries: -# d = entry["date"] -# d = datetime(d.year, d.month, d.day) if isinstance(d, datetime) else d -# daily_map[d] = entry["count"] # overwrite = no duplicates + # Build UNIQUE date -> count map + daily_map = {} + for entry in daily_entries: + d = entry["date"] + d = datetime(d.year, d.month, d.day) if isinstance(d, datetime) else d + daily_map[d] = entry["count"] # overwrite = no duplicates -# # Find last known date -# last_date = max(daily_map.keys()) if daily_map else today_date + # Find last known date + last_date = max(daily_map.keys()) if daily_map else today_date -# # Fill ALL missing days with 0 -# next_day = last_date + timedelta(days=1) -# while next_day < today_date: -# daily_map.setdefault(next_day, 0) -# next_day += timedelta(days=1) + # Fill ALL missing days with 0 + next_day = last_date + timedelta(days=1) + while next_day < today_date: + daily_map.setdefault(next_day, 0) + next_day += timedelta(days=1) -# # Mark today as used (binary) -# daily_map[today_date] = 1 + # Mark today as used (binary) + daily_map[today_date] = 1 -# # Rebuild list (OLD → NEW) -# final_daily_entries = [ -# {"date": d, "count": daily_map[d]} -# for d in sorted(daily_map.keys()) -# ] + # Rebuild list (OLD → NEW) + final_daily_entries = [ + {"date": d, "count": daily_map[d]} + for d in sorted(daily_map.keys()) + ] -# # Keep last 32 days only -# final_daily_entries = final_daily_entries[-32:] + # Keep last 32 days only + final_daily_entries = final_daily_entries[-32:] -# # ATOMIC REPLACE (NO PUSH) -# media_clicks_col.update_one( -# {"userId": user_oid}, -# { -# "$set": { -# "ai_edit_daily_count": final_daily_entries, -# "ai_edit_last_date": now, -# "updatedAt": now -# } -# } -# ) + # ATOMIC REPLACE (NO PUSH) + media_clicks_col.update_one( + {"userId": user_oid}, + { + "$set": { + "ai_edit_daily_count": final_daily_entries, + "ai_edit_last_date": now, + "updatedAt": now + } + } + ) -# # -------------------------------------------------- -# # CATEGORY CLICK LOGIC -# # -------------------------------------------------- -# update_res = media_clicks_col.update_one( -# {"userId": user_oid, "categories.categoryId": category_oid}, -# { -# "$set": { -# "updatedAt": now, -# "categories.$.lastClickedAt": now -# }, -# "$inc": { -# "categories.$.click_count": 1 -# } -# } -# ) + # -------------------------------------------------- + # CATEGORY CLICK LOGIC + # -------------------------------------------------- + update_res = media_clicks_col.update_one( + {"userId": user_oid, "categories.categoryId": category_oid}, + { + "$set": { + "updatedAt": now, + "categories.$.lastClickedAt": now + }, + "$inc": { + "categories.$.click_count": 1 + } + } + ) -# # If category does not exist → push new -# if update_res.matched_count == 0: -# media_clicks_col.update_one( -# {"userId": user_oid}, -# { -# "$set": {"updatedAt": now}, -# "$push": { -# "categories": { -# "categoryId": category_oid, -# "click_count": 1, -# "lastClickedAt": now -# } -# } -# }, -# upsert=True -# ) + # If category does not exist → push new + if update_res.matched_count == 0: + media_clicks_col.update_one( + {"userId": user_oid}, + { + "$set": {"updatedAt": now}, + "$push": { + "categories": { + "categoryId": category_oid, + "click_count": 1, + "lastClickedAt": now + } + } + }, + upsert=True + ) -# except Exception as e: -# print("CATEGORY_LOG_ERROR:", e) -# # ------------------------- -# # 4. HF INFERENCE -# # ------------------------- -# try: -# # -------------------------------------------------- -# # MODEL OVERRIDE BASED ON CATEGORY -# # -------------------------------------------------- -# force_model = None + except Exception as e: + print("CATEGORY_LOG_ERROR:", e) + # ------------------------- + # 4. HF INFERENCE + # ------------------------- + try: + # -------------------------------------------------- + # MODEL OVERRIDE BASED ON CATEGORY + # -------------------------------------------------- + force_model = None -# if category_id == GEMINI_FORCE_CATEGORY_ID: -# force_model = "GEMINI" + if category_id == GEMINI_FORCE_CATEGORY_ID: + force_model = "GEMINI" -# pil_output = run_image_generation( -# image1=pil_img1, -# image2=pil_img2, -# prompt=prompt, -# force_model="GEMINI" if category_id == GEMINI_FORCE_CATEGORY_ID else None -# ) + pil_output = run_image_generation( + image1=pil_img1, + image2=pil_img2, + prompt=prompt, + force_model="GEMINI" if category_id == GEMINI_FORCE_CATEGORY_ID else None + ) -# except Exception as e: -# response_time_ms = round((time.time() - start_time) * 1000) -# logs_collection.insert_one({ -# "timestamp": datetime.utcnow(), -# "status": "failure", -# "input1_id": input1_id, -# "input2_id": input2_id if input2_id else None, -# "prompt": prompt, -# "user_email": user.get("email"), -# "error": str(e), -# "response_time_ms": response_time_ms, -# "appname": appname -# }) -# raise HTTPException(500, f"Inference failed: {e}") + except Exception as e: + response_time_ms = round((time.time() - start_time) * 1000) + logs_collection.insert_one({ + "timestamp": datetime.utcnow(), + "status": "failure", + "input1_id": input1_id, + "input2_id": input2_id if input2_id else None, + "prompt": prompt, + "user_email": user.get("email"), + "error": str(e), + "response_time_ms": response_time_ms, + "appname": appname + }) + raise HTTPException(500, f"Inference failed: {e}") -# # ------------------------- -# # 5. SAVE OUTPUT IMAGE -# # ------------------------- -# output_image_id = generate_image_id() -# out_buf = io.BytesIO() -# pil_output.save(out_buf, format="PNG") -# out_bytes = out_buf.getvalue() + # ------------------------- + # 5. SAVE OUTPUT IMAGE + # ------------------------- + output_image_id = generate_image_id() + out_buf = io.BytesIO() + pil_output.save(out_buf, format="PNG") + out_bytes = out_buf.getvalue() -# out_filename = f"{output_image_id}_result.png" -# out_url = upload_to_digitalocean(out_bytes, "results", out_filename) + out_filename = f"{output_image_id}_result.png" + out_url = upload_to_digitalocean(out_bytes, "results", out_filename) -# # ------------------------- -# # 5b. SAVE COMPRESSED IMAGE -# # ------------------------- -# compressed_bytes = compress_pil_image_to_2mb( -# pil_output, -# max_dim=1280 -# ) + # ------------------------- + # 5b. SAVE COMPRESSED IMAGE + # ------------------------- + compressed_bytes = compress_pil_image_to_2mb( + pil_output, + max_dim=1280 + ) -# compressed_filename = f"{output_image_id}_compressed.jpg" -# compressed_url = upload_to_digitalocean(compressed_bytes, "results", compressed_filename) + compressed_filename = f"{output_image_id}_compressed.jpg" + compressed_url = upload_to_digitalocean(compressed_bytes, "results", compressed_filename) -# response_time_ms = round((time.time() - start_time) * 1000) + response_time_ms = round((time.time() - start_time) * 1000) -# # ------------------------- -# # 6. LOG SUCCESS -# # ------------------------- -# logs_collection.insert_one({ -# "timestamp": datetime.utcnow(), -# "status": "success", -# "input1_id": input1_id, -# "input2_id": input2_id if input2_id else None, -# "output_id": output_image_id, -# "prompt": prompt, -# "user_email": user.get("email"), -# "response_time_ms": response_time_ms, -# "appname": appname -# }) + # ------------------------- + # 6. LOG SUCCESS + # ------------------------- + logs_collection.insert_one({ + "timestamp": datetime.utcnow(), + "status": "success", + "input1_id": input1_id, + "input2_id": input2_id if input2_id else None, + "output_id": output_image_id, + "prompt": prompt, + "user_email": user.get("email"), + "response_time_ms": response_time_ms, + "appname": appname + }) -# return JSONResponse({ -# "output_image_id": output_image_id, -# "user": user.get("email"), -# "response_time_ms": response_time_ms, -# "Compressed_Image_URL": compressed_url -# }) + return JSONResponse({ + "output_image_id": output_image_id, + "user": user.get("email"), + "response_time_ms": response_time_ms, + "Compressed_Image_URL": compressed_url + }) -# # Image endpoint removed - images are now stored directly in DigitalOcean Spaces -# # and are publicly accessible via the Compressed_Image_URL returned in the /generate response +# Image endpoint removed - images are now stored directly in DigitalOcean Spaces +# and are publicly accessible via the Compressed_Image_URL returned in the /generate response -# # --------------------------------------------------------------------- -# # Run locally -# # --------------------------------------------------------------------- -# if __name__ == "__main__": -# import uvicorn -# uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True) -#################################################################################----------VERSION -1 CODE ----#################################################################### -# import os -# import io -# import json -# import traceback -# from datetime import datetime,timedelta -# from typing import Optional -# import time -# from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends -# from fastapi.responses import StreamingResponse, JSONResponse -# from fastapi.middleware.cors import CORSMiddleware -# from pydantic import BaseModel -# from pymongo import MongoClient -# import gridfs -# from bson.objectid import ObjectId -# from PIL import Image -# from fastapi.concurrency import run_in_threadpool -# import shutil -# import firebase_admin -# from firebase_admin import credentials, auth -# from PIL import Image -# from huggingface_hub import InferenceClient -# # --------------------------------------------------------------------- -# # Load Firebase Config from env (stringified JSON) -# # --------------------------------------------------------------------- -# firebase_config_json = os.getenv("firebase_config") -# if not firebase_config_json: -# raise RuntimeError("❌ Missing Firebase config in environment variable 'firebase_config'") +# --------------------------------------------------------------------- +# Run locally +# --------------------------------------------------------------------- +if __name__ == "__main__": + import uvicorn + uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True) +################################################################################----------VERSION -1 CODE ----#################################################################### +import os +import io +import json +import traceback +from datetime import datetime,timedelta +from typing import Optional +import time +from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends +from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from pymongo import MongoClient +import gridfs +from bson.objectid import ObjectId +from PIL import Image +from fastapi.concurrency import run_in_threadpool +import shutil +import firebase_admin +from firebase_admin import credentials, auth +from PIL import Image +from huggingface_hub import InferenceClient +# --------------------------------------------------------------------- +# Load Firebase Config from env (stringified JSON) +# --------------------------------------------------------------------- +firebase_config_json = os.getenv("firebase_config") +if not firebase_config_json: + raise RuntimeError("❌ Missing Firebase config in environment variable 'firebase_config'") -# try: -# firebase_creds_dict = json.loads(firebase_config_json) -# cred = credentials.Certificate(firebase_creds_dict) -# firebase_admin.initialize_app(cred) -# except Exception as e: -# raise RuntimeError(f"Failed to initialize Firebase Admin SDK: {e}") +try: + firebase_creds_dict = json.loads(firebase_config_json) + cred = credentials.Certificate(firebase_creds_dict) + firebase_admin.initialize_app(cred) +except Exception as e: + raise RuntimeError(f"Failed to initialize Firebase Admin SDK: {e}") -# # --------------------------------------------------------------------- -# # Hugging Face setup -# # --------------------------------------------------------------------- -# HF_TOKEN = os.getenv("HF_TOKEN") -# if not HF_TOKEN: -# raise RuntimeError("HF_TOKEN not set in environment variables") +# --------------------------------------------------------------------- +# Hugging Face setup +# --------------------------------------------------------------------- +HF_TOKEN = os.getenv("HF_TOKEN") +if not HF_TOKEN: + raise RuntimeError("HF_TOKEN not set in environment variables") -# hf_client = InferenceClient(token=HF_TOKEN) -# # --------------------------------------------------------------------- -# # MODEL SELECTION -# # --------------------------------------------------------------------- -# genai = None # ✅ IMPORTANT: module-level declaration -# MODEL = os.getenv("IMAGE_MODEL", "GEMINI").upper() -# GEMINI_FORCE_CATEGORY_ID = "69368e741224bcb6bdb98076" -# # --------------------------------------------------------------------- -# # Gemini setup (ONLY used if MODEL == "GEMINI") -# # --------------------------------------------------------------------- -# GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") -# GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-2.5-flash-image") +hf_client = InferenceClient(token=HF_TOKEN) +# --------------------------------------------------------------------- +# MODEL SELECTION +# --------------------------------------------------------------------- +genai = None # ✅ IMPORTANT: module-level declaration +MODEL = os.getenv("IMAGE_MODEL", "GEMINI").upper() +GEMINI_FORCE_CATEGORY_ID = "69368e741224bcb6bdb98076" +# --------------------------------------------------------------------- +# Gemini setup (ONLY used if MODEL == "GEMINI") +# --------------------------------------------------------------------- +GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") +GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-2.5-flash-image") -# # --------------------------------------------------------------------- -# # MongoDB setup -# # --------------------------------------------------------------------- -# MONGODB_URI=os.getenv("MONGODB_URI") -# DB_NAME = "polaroid_db" +# --------------------------------------------------------------------- +# MongoDB setup +# --------------------------------------------------------------------- +MONGODB_URI=os.getenv("MONGODB_URI") +DB_NAME = "polaroid_db" -# mongo = MongoClient(MONGODB_URI) -# db = mongo[DB_NAME] -# fs = gridfs.GridFS(db) -# logs_collection = db["logs"] +mongo = MongoClient(MONGODB_URI) +db = mongo[DB_NAME] +fs = gridfs.GridFS(db) +logs_collection = db["logs"] -# # --------------------------------------------------------------------- -# # FastAPI app setup -# # --------------------------------------------------------------------- -# app = FastAPI(title="Qwen Image Edit API with Firebase Auth") -# app.add_middleware( -# CORSMiddleware, -# allow_origins=["*"], -# allow_credentials=True, -# allow_methods=["*"], -# allow_headers=["*"], -# ) +# --------------------------------------------------------------------- +# FastAPI app setup +# --------------------------------------------------------------------- +app = FastAPI(title="Qwen Image Edit API with Firebase Auth") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) -# # --------------------------------------------------------------------- -# # Auth dependency -# # --------------------------------------------------------------------- -# async def verify_firebase_token(request: Request): -# """Middleware-like dependency to verify Firebase JWT from Authorization header.""" -# auth_header = request.headers.get("Authorization") -# if not auth_header or not auth_header.startswith("Bearer "): -# raise HTTPException(status_code=401, detail="Missing or invalid Authorization header") +# --------------------------------------------------------------------- +# Auth dependency +# --------------------------------------------------------------------- +async def verify_firebase_token(request: Request): + """Middleware-like dependency to verify Firebase JWT from Authorization header.""" + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Missing or invalid Authorization header") -# id_token = auth_header.split("Bearer ")[1] -# try: -# decoded_token = auth.verify_id_token(id_token) -# request.state.user = decoded_token -# return decoded_token -# except Exception as e: -# raise HTTPException(status_code=401, detail=f"Invalid or expired Firebase token: {e}") + id_token = auth_header.split("Bearer ")[1] + try: + decoded_token = auth.verify_id_token(id_token) + request.state.user = decoded_token + return decoded_token + except Exception as e: + raise HTTPException(status_code=401, detail=f"Invalid or expired Firebase token: {e}") -# # --------------------------------------------------------------------- -# # Models -# # --------------------------------------------------------------------- -# class HealthResponse(BaseModel): -# status: str -# db: str -# model: str +# --------------------------------------------------------------------- +# Models +# --------------------------------------------------------------------- +class HealthResponse(BaseModel): + status: str + db: str + model: str -# # --------------------- UTILS --------------------- -# def resize_image_if_needed(img: Image.Image, max_size=(1024, 1024)) -> Image.Image: -# """ -# Resize image to fit within max_size while keeping aspect ratio. -# """ -# if img.width > max_size[0] or img.height > max_size[1]: -# img.thumbnail(max_size, Image.ANTIALIAS) -# return img -# # --------------------------------------------------------------------- -# # Lazy Gemini Initialization -# # --------------------------------------------------------------------- -# _genai_initialized = False -# def init_gemini(): -# global _genai_initialized, genai +# --------------------- UTILS --------------------- +def resize_image_if_needed(img: Image.Image, max_size=(1024, 1024)) -> Image.Image: + """ + Resize image to fit within max_size while keeping aspect ratio. + """ + if img.width > max_size[0] or img.height > max_size[1]: + img.thumbnail(max_size, Image.ANTIALIAS) + return img +# --------------------------------------------------------------------- +# Lazy Gemini Initialization +# --------------------------------------------------------------------- +_genai_initialized = False +def init_gemini(): + global _genai_initialized, genai -# if _genai_initialized: -# return + if _genai_initialized: + return -# if not GEMINI_API_KEY: -# raise RuntimeError("❌ GEMINI_API_KEY not set") + if not GEMINI_API_KEY: + raise RuntimeError("❌ GEMINI_API_KEY not set") -# import google.generativeai as genai -# genai.configure(api_key=GEMINI_API_KEY) + import google.generativeai as genai + genai.configure(api_key=GEMINI_API_KEY) -# _genai_initialized = True + _genai_initialized = True -# def expiry(hours: int): -# """Return UTC datetime for TTL expiration""" -# return datetime.utcnow() + timedelta(hours=hours) +def expiry(hours: int): + """Return UTC datetime for TTL expiration""" + return datetime.utcnow() + timedelta(hours=hours) -# def prepare_image(file_bytes: bytes) -> Image.Image: -# """ -# Open image and resize if larger than 1024x1024 -# """ -# img = Image.open(io.BytesIO(file_bytes)).convert("RGB") +def prepare_image(file_bytes: bytes) -> Image.Image: + """ + Open image and resize if larger than 1024x1024 + """ + img = Image.open(io.BytesIO(file_bytes)).convert("RGB") -# # ✅ MIN SIZE CHECK -# if img.width < 200 or img.height < 200: -# raise HTTPException( -# status_code=400, -# detail="Image size is below 200x200 pixels. Please upload a larger image." -# ) -# img = Image.open(io.BytesIO(file_bytes)).convert("RGB") -# img = resize_image_if_needed(img, max_size=(1024, 1024)) -# return img + # ✅ MIN SIZE CHECK + if img.width < 200 or img.height < 200: + raise HTTPException( + status_code=400, + detail="Image size is below 200x200 pixels. Please upload a larger image." + ) + img = Image.open(io.BytesIO(file_bytes)).convert("RGB") + img = resize_image_if_needed(img, max_size=(1024, 1024)) + return img -# MAX_COMPRESSED_SIZE = 2 * 1024 * 1024 # 2 MB -# def compress_pil_image_to_2mb( -# pil_img: Image.Image, -# max_dim: int = 1280 -# ) -> bytes: -# """ -# Resize + compress PIL image to <= 2MB. -# Returns JPEG bytes. -# """ -# img = pil_img.convert("RGB") +MAX_COMPRESSED_SIZE = 2 * 1024 * 1024 # 2 MB +def compress_pil_image_to_2mb( + pil_img: Image.Image, + max_dim: int = 1280 +) -> bytes: + """ + Resize + compress PIL image to <= 2MB. + Returns JPEG bytes. + """ + img = pil_img.convert("RGB") -# # Resize (maintain aspect ratio) -# img.thumbnail((max_dim, max_dim), Image.LANCZOS) + # Resize (maintain aspect ratio) + img.thumbnail((max_dim, max_dim), Image.LANCZOS) -# quality = 85 -# buffer = io.BytesIO() + quality = 85 + buffer = io.BytesIO() -# while quality >= 40: -# buffer.seek(0) -# buffer.truncate() + while quality >= 40: + buffer.seek(0) + buffer.truncate() -# img.save( -# buffer, -# format="JPEG", -# quality=quality, -# optimize=True, -# progressive=True -# ) + img.save( + buffer, + format="JPEG", + quality=quality, + optimize=True, + progressive=True + ) -# if buffer.tell() <= MAX_COMPRESSED_SIZE: -# break + if buffer.tell() <= MAX_COMPRESSED_SIZE: + break -# quality -= 5 + quality -= 5 -# return buffer.getvalue() + return buffer.getvalue() -# def run_image_generation( -# image1: Image.Image, -# prompt: str, -# image2: Optional[Image.Image] = None, -# force_model: Optional[str] = None -# ) -> Image.Image: -# """ -# Unified image generation interface. -# QWEN -> merges images if image2 exists -# GEMINI -> passes images separately -# """ +def run_image_generation( + image1: Image.Image, + prompt: str, + image2: Optional[Image.Image] = None, + force_model: Optional[str] = None +) -> Image.Image: + """ + Unified image generation interface. + QWEN -> merges images if image2 exists + GEMINI -> passes images separately + """ -# effective_model = (force_model or MODEL).upper() + effective_model = (force_model or MODEL).upper() -# # ---------------- QWEN ---------------- -# if effective_model == "QWEN": + # ---------------- QWEN ---------------- + if effective_model == "QWEN": -# # ✅ Merge images ONLY for QWEN -# if image2: -# total_width = image1.width + image2.width -# max_height = max(image1.height, image2.height) -# merged = Image.new("RGB", (total_width, max_height)) -# merged.paste(image1, (0, 0)) -# merged.paste(image2, (image1.width, 0)) -# else: -# merged = image1 + # ✅ Merge images ONLY for QWEN + if image2: + total_width = image1.width + image2.width + max_height = max(image1.height, image2.height) + merged = Image.new("RGB", (total_width, max_height)) + merged.paste(image1, (0, 0)) + merged.paste(image2, (image1.width, 0)) + else: + merged = image1 -# return hf_client.image_to_image( -# image=merged, -# prompt=prompt, -# model="Qwen/Qwen-Image-Edit" -# ) + return hf_client.image_to_image( + image=merged, + prompt=prompt, + model="Qwen/Qwen-Image-Edit" + ) -# # ---------------- GEMINI ---------------- -# elif effective_model == "GEMINI": -# init_gemini() -# model = genai.GenerativeModel(GEMINI_IMAGE_MODEL) + # ---------------- GEMINI ---------------- + elif effective_model == "GEMINI": + init_gemini() + model = genai.GenerativeModel(GEMINI_IMAGE_MODEL) -# parts = [prompt] + parts = [prompt] -# def add_image(img: Image.Image): -# buf = io.BytesIO() -# img.save(buf, format="PNG") -# buf.seek(0) -# parts.append({ -# "mime_type": "image/png", -# "data": buf.getvalue() -# }) + def add_image(img: Image.Image): + buf = io.BytesIO() + img.save(buf, format="PNG") + buf.seek(0) + parts.append({ + "mime_type": "image/png", + "data": buf.getvalue() + }) -# add_image(image1) + add_image(image1) -# if image2: -# add_image(image2) + if image2: + add_image(image2) -# response = model.generate_content(parts) + response = model.generate_content(parts) -# # -------- SAFE IMAGE EXTRACTION -------- -# import base64 + # -------- SAFE IMAGE EXTRACTION -------- + import base64 -# image_bytes = None -# for candidate in response.candidates: -# for part in candidate.content.parts: -# if hasattr(part, "inline_data") and part.inline_data: -# data = part.inline_data.data -# image_bytes = ( -# data if isinstance(data, (bytes, bytearray)) -# else base64.b64decode(data) -# ) -# break -# if image_bytes: -# break + image_bytes = None + for candidate in response.candidates: + for part in candidate.content.parts: + if hasattr(part, "inline_data") and part.inline_data: + data = part.inline_data.data + image_bytes = ( + data if isinstance(data, (bytes, bytearray)) + else base64.b64decode(data) + ) + break + if image_bytes: + break -# if not image_bytes: -# raise RuntimeError("Gemini did not return an image") + if not image_bytes: + raise RuntimeError("Gemini did not return an image") -# img = Image.open(io.BytesIO(image_bytes)) -# img.verify() -# return Image.open(io.BytesIO(image_bytes)).convert("RGB") + img = Image.open(io.BytesIO(image_bytes)) + img.verify() + return Image.open(io.BytesIO(image_bytes)).convert("RGB") -# else: -# raise RuntimeError(f"Unsupported IMAGE_MODEL: {effective_model}") + else: + raise RuntimeError(f"Unsupported IMAGE_MODEL: {effective_model}") -# def get_admin_db(appname: Optional[str]): -# """ -# Returns (categories_collection, media_clicks_collection) -# based on appname -# """ -# # ---- Collage Maker ---- -# if appname == "collage-maker": -# collage_uri = os.getenv("COLLAGE_MAKER_DB_URL") -# if not collage_uri: -# raise RuntimeError("COLLAGE_MAKER_DB_URL not set") +def get_admin_db(appname: Optional[str]): + """ + Returns (categories_collection, media_clicks_collection) + based on appname + """ + # ---- Collage Maker ---- + if appname == "collage-maker": + collage_uri = os.getenv("COLLAGE_MAKER_DB_URL") + if not collage_uri: + raise RuntimeError("COLLAGE_MAKER_DB_URL not set") -# client = MongoClient(collage_uri) -# db = client["adminPanel"] -# return db.categories, db.media_clicks + client = MongoClient(collage_uri) + db = client["adminPanel"] + return db.categories, db.media_clicks -# # ---- AI ENHANCER ---- -# if appname == "AI-Enhancer": -# enhancer_uri = os.getenv("AI_ENHANCER_DB_URL") -# if not enhancer_uri: -# raise RuntimeError("AI_ENHANCER_DB_URL not set") + # ---- AI ENHANCER ---- + if appname == "AI-Enhancer": + enhancer_uri = os.getenv("AI_ENHANCER_DB_URL") + if not enhancer_uri: + raise RuntimeError("AI_ENHANCER_DB_URL not set") -# client = MongoClient(enhancer_uri) -# db = client["test"] -# return db.categories, db.media_clicks + client = MongoClient(enhancer_uri) + db = client["test"] + return db.categories, db.media_clicks -# # DEFAULT (existing behavior) -# admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) -# db = admin_client["adminPanel"] -# return db.categories, db.media_clicks + # DEFAULT (existing behavior) + admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) + db = admin_client["adminPanel"] + return db.categories, db.media_clicks -# # --------------------------------------------------------------------- -# # Endpoints -# # --------------------------------------------------------------------- -# @app.get("/") -# async def root(): -# """Root endpoint""" -# return { -# "success": True, -# "message": "Polaroid,Kiddo,Makeup,Hairstyle API", -# "data": { -# "version": "1.0.0", -# "Product Name":"Beauty Camera - GlowCam AI Studio", -# "Released By" : "LogicGo Infotech" -# } -# } +# --------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------- +@app.get("/") +async def root(): + """Root endpoint""" + return { + "success": True, + "message": "Polaroid,Kiddo,Makeup,Hairstyle API", + "data": { + "version": "1.0.0", + "Product Name":"Beauty Camera - GlowCam AI Studio", + "Released By" : "LogicGo Infotech" + } + } -# @app.get("/health", response_model=HealthResponse) -# def health(): -# """Public health check""" -# mongo.admin.command("ping") -# return HealthResponse(status="ok", db=db.name, model="Qwen/Qwen-Image-Edit") +@app.get("/health", response_model=HealthResponse) +def health(): + """Public health check""" + mongo.admin.command("ping") + return HealthResponse(status="ok", db=db.name, model="Qwen/Qwen-Image-Edit") -# @app.post("/generate") -# async def generate( -# prompt: str = Form(...), -# image1: UploadFile = File(...), -# image2: Optional[UploadFile] = File(None), -# user_id: Optional[str] = Form(None), -# category_id: Optional[str] = Form(None), -# appname: Optional[str] = Form(None), -# user=Depends(verify_firebase_token) -# ): -# start_time = time.time() +@app.post("/generate") +async def generate( + prompt: str = Form(...), + image1: UploadFile = File(...), + image2: Optional[UploadFile] = File(None), + user_id: Optional[str] = Form(None), + category_id: Optional[str] = Form(None), + appname: Optional[str] = Form(None), + user=Depends(verify_firebase_token) +): + start_time = time.time() -# # ------------------------- -# # 1. VALIDATE & READ IMAGES -# # ------------------------- -# try: -# img1_bytes = await image1.read() -# pil_img1 = prepare_image(img1_bytes) -# input1_id = fs.put( -# img1_bytes, -# filename=image1.filename, -# contentType=image1.content_type, -# metadata={"role": "input"}, -# expireAt=expiry(6) # delete after 6 hours -# ) -# except Exception as e: -# raise HTTPException(400, f"Failed to read first image: {e}") + # ------------------------- + # 1. VALIDATE & READ IMAGES + # ------------------------- + try: + img1_bytes = await image1.read() + pil_img1 = prepare_image(img1_bytes) + input1_id = fs.put( + img1_bytes, + filename=image1.filename, + contentType=image1.content_type, + metadata={"role": "input"}, + expireAt=expiry(6) # delete after 6 hours + ) + except Exception as e: + raise HTTPException(400, f"Failed to read first image: {e}") -# img2_bytes = None -# input2_id = None -# pil_img2 = None + img2_bytes = None + input2_id = None + pil_img2 = None -# if image2: -# try: -# img2_bytes = await image2.read() -# pil_img2 = prepare_image(img2_bytes) -# input2_id = fs.put( -# img2_bytes, -# filename=image2.filename, -# contentType=image2.content_type, -# metadata={"role": "input"}, -# expireAt=expiry(6) -# ) -# except Exception as e: -# raise HTTPException(400, f"Failed to read second image: {e}") + if image2: + try: + img2_bytes = await image2.read() + pil_img2 = prepare_image(img2_bytes) + input2_id = fs.put( + img2_bytes, + filename=image2.filename, + contentType=image2.content_type, + metadata={"role": "input"}, + expireAt=expiry(6) + ) + except Exception as e: + raise HTTPException(400, f"Failed to read second image: {e}") -# # ------------------------- -# # 3. CATEGORY CLICK LOGIC -# # ------------------------- -# if user_id and category_id: -# try: -# admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) -# admin_db = admin_client["adminPanel"] + # ------------------------- + # 3. CATEGORY CLICK LOGIC + # ------------------------- + if user_id and category_id: + try: + admin_client = MongoClient(os.getenv("ADMIN_MONGODB_URI")) + admin_db = admin_client["adminPanel"] -# categories_col, media_clicks_col = get_admin_db(appname) -# # categories_col = admin_db.categories -# # media_clicks_col = admin_db.media_clicks + categories_col, media_clicks_col = get_admin_db(appname) + # categories_col = admin_db.categories + # media_clicks_col = admin_db.media_clicks -# # Validate user_oid & category_oid -# user_oid = ObjectId(user_id) -# category_oid = ObjectId(category_id) + # Validate user_oid & category_oid + user_oid = ObjectId(user_id) + category_oid = ObjectId(category_id) -# # Check category exists -# category_doc = categories_col.find_one({"_id": category_oid}) -# if not category_doc: -# raise HTTPException(400, f"Invalid category_id: {category_id}") + # Check category exists + category_doc = categories_col.find_one({"_id": category_oid}) + if not category_doc: + raise HTTPException(400, f"Invalid category_id: {category_id}") -# now = datetime.utcnow() + now = datetime.utcnow() -# # Normalize dates (UTC midnight) -# today_date = datetime(now.year, now.month, now.day) -# yesterday_date = today_date - timedelta(days=1) + # Normalize dates (UTC midnight) + today_date = datetime(now.year, now.month, now.day) + yesterday_date = today_date - timedelta(days=1) -# # -------------------------------------------------- -# # AI EDIT USAGE TRACKING (GLOBAL PER USER) -# # -------------------------------------------------- -# media_clicks_col.update_one( -# {"userId": user_oid}, -# { -# "$setOnInsert": { -# "createdAt": now, -# "ai_edit_daily_count": [] -# }, -# "$set": { -# "ai_edit_last_date": now, -# "updatedAt": now -# }, -# "$inc": { -# "ai_edit_complete": 1 -# } -# }, -# upsert=True -# ) + # -------------------------------------------------- + # AI EDIT USAGE TRACKING (GLOBAL PER USER) + # -------------------------------------------------- + media_clicks_col.update_one( + {"userId": user_oid}, + { + "$setOnInsert": { + "createdAt": now, + "ai_edit_daily_count": [] + }, + "$set": { + "ai_edit_last_date": now, + "updatedAt": now + }, + "$inc": { + "ai_edit_complete": 1 + } + }, + upsert=True + ) -# # -------------------------------------------------- -# # DAILY COUNT LOGIC -# # -------------------------------------------------- -# now = datetime.utcnow() -# today_date = datetime(now.year, now.month, now.day) + # -------------------------------------------------- + # DAILY COUNT LOGIC + # -------------------------------------------------- + now = datetime.utcnow() + today_date = datetime(now.year, now.month, now.day) -# doc = media_clicks_col.find_one( -# {"userId": user_oid}, -# {"ai_edit_daily_count": 1} -# ) + doc = media_clicks_col.find_one( + {"userId": user_oid}, + {"ai_edit_daily_count": 1} + ) -# daily_entries = doc.get("ai_edit_daily_count", []) if doc else [] + daily_entries = doc.get("ai_edit_daily_count", []) if doc else [] -# # Build UNIQUE date -> count map -# daily_map = {} -# for entry in daily_entries: -# d = entry["date"] -# d = datetime(d.year, d.month, d.day) if isinstance(d, datetime) else d -# daily_map[d] = entry["count"] # overwrite = no duplicates + # Build UNIQUE date -> count map + daily_map = {} + for entry in daily_entries: + d = entry["date"] + d = datetime(d.year, d.month, d.day) if isinstance(d, datetime) else d + daily_map[d] = entry["count"] # overwrite = no duplicates -# # Find last known date -# last_date = max(daily_map.keys()) if daily_map else today_date + # Find last known date + last_date = max(daily_map.keys()) if daily_map else today_date -# # Fill ALL missing days with 0 -# next_day = last_date + timedelta(days=1) -# while next_day < today_date: -# daily_map.setdefault(next_day, 0) -# next_day += timedelta(days=1) + # Fill ALL missing days with 0 + next_day = last_date + timedelta(days=1) + while next_day < today_date: + daily_map.setdefault(next_day, 0) + next_day += timedelta(days=1) -# # Mark today as used (binary) -# daily_map[today_date] = 1 + # Mark today as used (binary) + daily_map[today_date] = 1 -# # Rebuild list (OLD → NEW) -# final_daily_entries = [ -# {"date": d, "count": daily_map[d]} -# for d in sorted(daily_map.keys()) -# ] + # Rebuild list (OLD → NEW) + final_daily_entries = [ + {"date": d, "count": daily_map[d]} + for d in sorted(daily_map.keys()) + ] -# # Keep last 32 days only -# final_daily_entries = final_daily_entries[-32:] + # Keep last 32 days only + final_daily_entries = final_daily_entries[-32:] -# # ATOMIC REPLACE (NO PUSH) -# media_clicks_col.update_one( -# {"userId": user_oid}, -# { -# "$set": { -# "ai_edit_daily_count": final_daily_entries, -# "ai_edit_last_date": now, -# "updatedAt": now -# } -# } -# ) + # ATOMIC REPLACE (NO PUSH) + media_clicks_col.update_one( + {"userId": user_oid}, + { + "$set": { + "ai_edit_daily_count": final_daily_entries, + "ai_edit_last_date": now, + "updatedAt": now + } + } + ) -# # -------------------------------------------------- -# # CATEGORY CLICK LOGIC -# # -------------------------------------------------- -# update_res = media_clicks_col.update_one( -# {"userId": user_oid, "categories.categoryId": category_oid}, -# { -# "$set": { -# "updatedAt": now, -# "categories.$.lastClickedAt": now -# }, -# "$inc": { -# "categories.$.click_count": 1 -# } -# } -# ) + # -------------------------------------------------- + # CATEGORY CLICK LOGIC + # -------------------------------------------------- + update_res = media_clicks_col.update_one( + {"userId": user_oid, "categories.categoryId": category_oid}, + { + "$set": { + "updatedAt": now, + "categories.$.lastClickedAt": now + }, + "$inc": { + "categories.$.click_count": 1 + } + } + ) -# # If category does not exist → push new -# if update_res.matched_count == 0: -# media_clicks_col.update_one( -# {"userId": user_oid}, -# { -# "$set": {"updatedAt": now}, -# "$push": { -# "categories": { -# "categoryId": category_oid, -# "click_count": 1, -# "lastClickedAt": now -# } -# } -# }, -# upsert=True -# ) + # If category does not exist → push new + if update_res.matched_count == 0: + media_clicks_col.update_one( + {"userId": user_oid}, + { + "$set": {"updatedAt": now}, + "$push": { + "categories": { + "categoryId": category_oid, + "click_count": 1, + "lastClickedAt": now + } + } + }, + upsert=True + ) -# except Exception as e: -# print("CATEGORY_LOG_ERROR:", e) -# # ------------------------- -# # 4. HF INFERENCE -# # ------------------------- -# try: -# # -------------------------------------------------- -# # MODEL OVERRIDE BASED ON CATEGORY -# # -------------------------------------------------- -# force_model = None + except Exception as e: + print("CATEGORY_LOG_ERROR:", e) + # ------------------------- + # 4. HF INFERENCE + # ------------------------- + try: + # -------------------------------------------------- + # MODEL OVERRIDE BASED ON CATEGORY + # -------------------------------------------------- + force_model = None -# if category_id == GEMINI_FORCE_CATEGORY_ID: -# force_model = "GEMINI" + if category_id == GEMINI_FORCE_CATEGORY_ID: + force_model = "GEMINI" -# pil_output = run_image_generation( -# image1=pil_img1, -# image2=pil_img2, -# prompt=prompt, -# force_model="GEMINI" if category_id == GEMINI_FORCE_CATEGORY_ID else None -# ) + pil_output = run_image_generation( + image1=pil_img1, + image2=pil_img2, + prompt=prompt, + force_model="GEMINI" if category_id == GEMINI_FORCE_CATEGORY_ID else None + ) -# except Exception as e: -# response_time_ms = round((time.time() - start_time) * 1000) -# logs_collection.insert_one({ -# "timestamp": datetime.utcnow(), -# "status": "failure", -# "input1_id": str(input1_id), -# "input2_id": str(input2_id) if input2_id else None, -# "prompt": prompt, -# "user_email": user.get("email"), -# "error": str(e), -# "response_time_ms": response_time_ms, -# "appname": appname -# }) -# raise HTTPException(500, f"Inference failed: {e}") + except Exception as e: + response_time_ms = round((time.time() - start_time) * 1000) + logs_collection.insert_one({ + "timestamp": datetime.utcnow(), + "status": "failure", + "input1_id": str(input1_id), + "input2_id": str(input2_id) if input2_id else None, + "prompt": prompt, + "user_email": user.get("email"), + "error": str(e), + "response_time_ms": response_time_ms, + "appname": appname + }) + raise HTTPException(500, f"Inference failed: {e}") -# # ------------------------- -# # 5. SAVE OUTPUT IMAGE -# # ------------------------- -# out_buf = io.BytesIO() -# pil_output.save(out_buf, format="PNG") -# out_bytes = out_buf.getvalue() + # ------------------------- + # 5. SAVE OUTPUT IMAGE + # ------------------------- + out_buf = io.BytesIO() + pil_output.save(out_buf, format="PNG") + out_bytes = out_buf.getvalue() -# out_id = fs.put( -# out_bytes, -# filename=f"result_{input1_id}.png", -# contentType="image/png", -# metadata={ -# "role": "output", -# "prompt": prompt, -# "input1_id": str(input1_id), -# "input2_id": str(input2_id) if input2_id else None, -# "user_email": user.get("email"), -# }, -# expireAt=expiry(24) # 24 hours -# ) -# # ------------------------- -# # 5b. SAVE COMPRESSED IMAGE -# # ------------------------- -# compressed_bytes = compress_pil_image_to_2mb( -# pil_output, -# max_dim=1280 -# ) + out_id = fs.put( + out_bytes, + filename=f"result_{input1_id}.png", + contentType="image/png", + metadata={ + "role": "output", + "prompt": prompt, + "input1_id": str(input1_id), + "input2_id": str(input2_id) if input2_id else None, + "user_email": user.get("email"), + }, + expireAt=expiry(24) # 24 hours + ) + # ------------------------- + # 5b. SAVE COMPRESSED IMAGE + # ------------------------- + compressed_bytes = compress_pil_image_to_2mb( + pil_output, + max_dim=1280 + ) -# compressed_id = fs.put( -# compressed_bytes, -# filename=f"result_{input1_id}_compressed.jpg", -# contentType="image/jpeg", -# metadata={ -# "role": "output_compressed", -# "original_output_id": str(out_id), -# "prompt": prompt, -# "user_email": user.get("email") -# }, -# expireAt=expiry(24) # 24 hours -# ) + compressed_id = fs.put( + compressed_bytes, + filename=f"result_{input1_id}_compressed.jpg", + contentType="image/jpeg", + metadata={ + "role": "output_compressed", + "original_output_id": str(out_id), + "prompt": prompt, + "user_email": user.get("email") + }, + expireAt=expiry(24) # 24 hours + ) -# response_time_ms = round((time.time() - start_time) * 1000) + response_time_ms = round((time.time() - start_time) * 1000) -# # ------------------------- -# # 6. LOG SUCCESS -# # ------------------------- -# logs_collection.insert_one({ -# "timestamp": datetime.utcnow(), -# "status": "success", -# "input1_id": str(input1_id), -# "input2_id": str(input2_id) if input2_id else None, -# "output_id": str(out_id), -# "prompt": prompt, -# "user_email": user.get("email"), -# "response_time_ms": response_time_ms, -# "appname": appname -# }) + # ------------------------- + # 6. LOG SUCCESS + # ------------------------- + logs_collection.insert_one({ + "timestamp": datetime.utcnow(), + "status": "success", + "input1_id": str(input1_id), + "input2_id": str(input2_id) if input2_id else None, + "output_id": str(out_id), + "prompt": prompt, + "user_email": user.get("email"), + "response_time_ms": response_time_ms, + "appname": appname + }) -# return JSONResponse({ -# "output_image_id": str(out_id), -# "user": user.get("email"), -# "response_time_ms": response_time_ms, -# "Compressed_Image_URL": ( -# f"https://logicgoinfotechspaces-polaroidimage.hf.space/image/{compressed_id}" -# ) -# }) + return JSONResponse({ + "output_image_id": str(out_id), + "user": user.get("email"), + "response_time_ms": response_time_ms, + "Compressed_Image_URL": ( + f"https://logicgoinfotechspaces-polaroidimage.hf.space/image/{compressed_id}" + ) + }) -# @app.get("/image/{image_id}") -# def get_image(image_id: str, download: Optional[bool] = False): -# """Retrieve stored image by ID (no authentication required).""" -# try: -# oid = ObjectId(image_id) -# grid_out = fs.get(oid) -# except Exception: -# raise HTTPException(status_code=404, detail="Image not found") +@app.get("/image/{image_id}") +def get_image(image_id: str, download: Optional[bool] = False): + """Retrieve stored image by ID (no authentication required).""" + try: + oid = ObjectId(image_id) + grid_out = fs.get(oid) + except Exception: + raise HTTPException(status_code=404, detail="Image not found") -# def iterfile(): -# yield grid_out.read() + def iterfile(): + yield grid_out.read() -# headers = {} -# if download: -# headers["Content-Disposition"] = f'attachment; filename="{grid_out.filename}"' + headers = {} + if download: + headers["Content-Disposition"] = f'attachment; filename="{grid_out.filename}"' -# return StreamingResponse( -# iterfile(), -# media_type=grid_out.content_type or "application/octet-stream", -# headers=headers -# ) + return StreamingResponse( + iterfile(), + media_type=grid_out.content_type or "application/octet-stream", + headers=headers + ) -# # --------------------------------------------------------------------- -# # Run locally -# # --------------------------------------------------------------------- -# if __name__ == "__main__": -# import uvicorn -# uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True) +# --------------------------------------------------------------------- +# Run locally +# --------------------------------------------------------------------- +if __name__ == "__main__": + import uvicorn + uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)