|
|
import os
import threading
import time
import uuid
from datetime import datetime

import google.generativeai as genai
import gradio as gr
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image
from PIL import ImageFilter, ImageOps
from rembg import remove
|
|
|
|
|
|
|
|
|
|
|
|
|
|
load_dotenv()

# GEMINI_API_KEY takes precedence; GOOGLE_API_KEY is accepted as a fallback.
api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
# bool() so an empty variable is reported as "not loaded", consistent with the
# `if not api_key` check below (which also rejects empty strings).
print("Gemini API Key Loaded:", bool(api_key))

if not api_key:
    raise RuntimeError("❌ No Gemini API Key found in .env. Please set GEMINI_API_KEY or GOOGLE_API_KEY.")

genai.configure(api_key=api_key)

# The model name can be overridden via GEMINI_MODEL (e.g. to move to a newer
# Gemini release) without a code change; the default is the original model.
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-1.5-flash")
model = genai.GenerativeModel(GEMINI_MODEL)
|
|
|
|
|
# Storage layout. Paths are relative to the process's working directory.
UPLOAD_DIR = "uploads"      # raw uploaded images
RESULTS_DIR = "results"     # composited and exported images
BG_DIR = "backgrounds"      # stock backgrounds offered for selection
MAX_SIZE_MB = 5             # upload size cap, in MiB (see check_size)
LIFETIME = 24 * 60 * 60     # max age of a generated file in seconds (one day)

# Make sure every storage folder exists before anything writes to it.
for _directory in (UPLOAD_DIR, RESULTS_DIR, BG_DIR):
    os.makedirs(_directory, exist_ok=True)
del _directory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cleanup_old_files(folder, lifetime=None):
    """Delete regular files in *folder* whose modification time is too old.

    Args:
        folder: Directory to sweep. Non-recursive; subdirectories are skipped.
        lifetime: Maximum age in seconds. When omitted, the module-level
            ``LIFETIME`` is used, so existing callers behave as before.

    A file that disappears during the sweep (for example, removed by a
    concurrent sweep from another request thread) is skipped instead of
    aborting the whole cleanup with ``FileNotFoundError``.
    """
    max_age = LIFETIME if lifetime is None else lifetime
    now = time.time()
    with os.scandir(folder) as entries:
        for entry in entries:
            try:
                if entry.is_file() and now - entry.stat().st_mtime > max_age:
                    os.remove(entry.path)
            except FileNotFoundError:
                continue
|
|
|
|
|
def check_size(filepath):
    """Enforce the ``MAX_SIZE_MB`` (mebibyte) cap on a file already on disk.

    A file within the limit is left untouched. An oversized file is deleted
    before the error is raised, so it does not linger on disk.

    Raises:
        ValueError: *filepath* is larger than the cap.
    """
    size_in_bytes = os.path.getsize(filepath)
    limit_in_bytes = MAX_SIZE_MB * 1024 * 1024
    if size_in_bytes <= limit_in_bytes:
        return
    os.remove(filepath)
    raise ValueError(f"File too large! Max {MAX_SIZE_MB}MB allowed.")
|
|
|
|
|
def replace_background(input_path, bg_choice):
    """Replace the background of an uploaded image with a stock background.

    Args:
        input_path: Path of the uploaded image on disk. ``check_size`` runs
            first; it deletes the file and raises ``ValueError`` when the file
            exceeds ``MAX_SIZE_MB``.
        bg_choice: File name of a background image in ``BG_DIR``. Only its
            final path component is used, so client-supplied values such as
            ``"../../x.png"`` or ``"/abs/path.png"`` cannot reach files
            outside ``BG_DIR``.

    Returns:
        Path of the composited PNG written to ``RESULTS_DIR``.

    Raises:
        ValueError: the upload is too large, or *bg_choice* is empty/invalid.
        FileNotFoundError: no such background exists in ``BG_DIR``.
    """
    check_size(input_path)

    # bg_choice arrives straight from an HTTP form field: strip directory
    # components so it can only name a file directly inside BG_DIR.
    bg_name = os.path.basename(bg_choice or "")
    if bg_name in ("", ".", ".."):
        raise ValueError("Invalid background choice.")
    bg_path = os.path.join(BG_DIR, bg_name)
    # Fail fast before the expensive background-removal step.
    if not os.path.isfile(bg_path):
        raise FileNotFoundError(f"Background not found: {bg_name}")

    # `with` closes the file handles promptly; convert()/remove() return
    # in-memory images that outlive them.
    with Image.open(input_path) as src:
        fg = remove(src.convert("RGBA"))
    with Image.open(bg_path) as bg_src:
        bg = bg_src.convert("RGBA").resize(fg.size)

    result = Image.alpha_composite(bg, fg)

    # The random suffix keeps requests finishing in the same second from
    # overwriting each other's output.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    result_path = os.path.join(RESULTS_DIR, f"result_{timestamp}_{uuid.uuid4().hex[:8]}.png")
    result.save(result_path)

    cleanup_old_files(RESULTS_DIR)
    return result_path
|
|
|
|
|
|
|
|
def _load_background(bg_choice, bg_upload, size):
    """Return the RGBA background resized to *size*.

    A user-uploaded background (*bg_upload*, a PIL image) takes precedence.
    Otherwise *bg_choice* is loaded from ``BG_DIR``, using only its bare file
    name so it cannot point outside that folder.

    Raises:
        ValueError: neither an upload nor a background name was given.
        FileNotFoundError: *bg_choice* does not exist in ``BG_DIR``.
    """
    if bg_upload is not None:
        return bg_upload.convert("RGBA").resize(size)
    bg_name = os.path.basename(bg_choice or "")
    if bg_name in ("", ".", ".."):
        raise ValueError("No background selected. Choose one from the list or upload your own.")
    # convert() yields an in-memory copy, so the file can be closed right away.
    with Image.open(os.path.join(BG_DIR, bg_name)) as src:
        return src.convert("RGBA").resize(size)


def _match_brightness(fg, bg, blend_strength):
    """Best-effort: pull the cut-out's brightness toward the background's.

    The foreground mean is measured only over its visible pixels (alpha > 0).
    The transparent cut-away area carries arbitrary (often black) RGB values;
    including it would drag the mean down and over-brighten the subject.
    The brightened version is mixed back in with weight *blend_strength*.
    Returns *fg* unchanged when there is nothing to measure or anything fails.
    """
    try:
        from PIL import ImageEnhance, ImageStat

        alpha = fg.getchannel("A")
        if alpha.getbbox() is None:
            return fg  # fully transparent: no subject to match
        bg_brightness = ImageStat.Stat(bg.convert("L")).mean[0]
        fg_brightness = ImageStat.Stat(fg.convert("L"), alpha).mean[0]
        if fg_brightness <= 0:
            return fg
        adjusted = ImageEnhance.Brightness(fg).enhance(bg_brightness / fg_brightness)
        return Image.blend(fg, adjusted, alpha=blend_strength)
    except Exception as e:  # deliberate best effort: compositing continues
        print("Color match failed:", e)
        return fg


def _apply_logo(canvas, logo_upload, logo_transparency, logo_position, margin=10):
    """Paste *logo_upload* onto *canvas* (modified in place).

    The logo is shrunk to fit a square one fifth of the canvas width. Its
    alpha is multiplied by ``logo_transparency / 100``, so the value acts as
    an opacity: 100 keeps the logo's own alpha and 0 hides it. The logo is
    placed at *logo_position* with *margin* pixels of padding; names other
    than the five known positions fall back to ``"Bottom-Right"``.
    """
    logo = logo_upload.convert("RGBA")  # a copy: the caller's image is not mutated
    side = max(1, canvas.width // 5)  # thumbnail() needs a positive bound
    logo.thumbnail((side, side))

    opacity = logo_transparency / 100
    logo.putalpha(logo.getchannel("A").point(lambda p: round(p * opacity)))

    right = canvas.width - logo.width - margin
    bottom = canvas.height - logo.height - margin
    positions = {
        "Top-Left": (margin, margin),
        "Top-Right": (right, margin),
        "Bottom-Left": (margin, bottom),
        "Bottom-Right": (right, bottom),
        "Center": ((canvas.width - logo.width) // 2, (canvas.height - logo.height) // 2),
    }
    xy = positions.get(logo_position, positions["Bottom-Right"])
    canvas.paste(logo, xy, logo)


def process_image(input_img, bg_choice, bg_upload, logo_upload, logo_transparency, logo_position, blur_background, blend_strength):
    """Composite the subject of *input_img* onto a new background (UI pipeline).

    Args:
        input_img: Main photo as a PIL image, or None.
        bg_choice: Stock background file name in ``BG_DIR``; used when
            *bg_upload* is None.
        bg_upload: Optional custom background (PIL image); takes precedence.
        logo_upload: Optional logo (PIL image) to stamp onto the result.
        logo_transparency: Logo opacity in percent (100 = logo's own alpha).
        logo_position: "Top-Left", "Top-Right", "Bottom-Left",
            "Bottom-Right" or "Center".
        blur_background: Apply a Gaussian blur (radius 3) to the background.
        blend_strength: Weight (0..1 in the UI) of the brightness-matched
            foreground when blending it with the original cut-out.

    Returns:
        ``[path]`` of the PNG written to ``RESULTS_DIR``, or ``[]`` when
        *input_img* is None.

    Raises:
        ValueError: no background was selected or uploaded.
        FileNotFoundError: the selected stock background does not exist.
    """
    if input_img is None:
        return []

    bg = _load_background(bg_choice, bg_upload, input_img.size)
    if blur_background:
        bg = bg.filter(ImageFilter.GaussianBlur(radius=3))

    fg = remove(input_img.convert("RGBA"))
    # Feather the cut-out edge so the subject doesn't look pasted on.
    fg.putalpha(fg.getchannel("A").filter(ImageFilter.GaussianBlur(radius=2)))

    fg = _match_brightness(fg, bg, blend_strength)
    result = Image.alpha_composite(bg, fg)

    if logo_upload is not None:
        _apply_logo(result, logo_upload, logo_transparency, logo_position)

    # The random suffix prevents same-second requests from clobbering each other.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    result_path = os.path.join(RESULTS_DIR, f"result_{timestamp}_{uuid.uuid4().hex[:8]}.png")
    result.save(result_path)

    cleanup_old_files(RESULTS_DIR)
    return [result_path]
|
|
|
|
|
def generate_caption(prompt="Promote my product"):
    """Ask Gemini for ready-to-post marketing captions about *prompt*.

    Args:
        prompt: Free-text description of the product or promotion.

    Returns:
        The model's reply with surrounding whitespace stripped. This function
        never raises: any exception is reported as a string beginning with
        ``"❌ Error generating captions: "``.
    """
    try:
        request_text = "".join((
            f"Write 3 catchy marketing captions for social media about: {prompt}. ",
            "Each caption should include persuasive language, emojis, and 3-5 relevant hashtags. ",
            "Format output clearly as:\nInstagram:\nFacebook:\nTikTok:\n",
        ))
        reply = model.generate_content(request_text)
        caption_text = reply.text
        return caption_text.strip()
    except Exception as err:
        return "❌ Error generating captions: " + str(err)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(title="SnapLift API")

# Allowed browser origins, as a comma-separated list in CORS_ALLOW_ORIGINS
# (e.g. "https://snaplift.example,http://localhost:3000"). Defaults to "*".
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed CORS must not be combined with a wildcard origin: browsers
    # reject that pairing, and honouring it would let any site send
    # credentialed requests. This API uses no cookies or auth, so credentials
    # are only enabled for an explicit allow-list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
@app.post("/process-image")
async def process_image_api(file: UploadFile, bg_choice: str = Form(...)):
    """Replace the background of an uploaded photo and respond with the PNG.

    Form fields:
        file: The product photo (multipart upload).
        bg_choice: File name of a background in ``BG_DIR``.

    Returns:
        ``FileResponse`` with the composited PNG, or a JSON body
        ``{"error": "<message>"}`` with status 400 on any failure.
    """
    # The client-supplied filename is never used as a path: it could contain
    # "../" segments, or collide with a concurrent upload of the same name.
    # Each request gets its own server-generated file instead.
    input_path = os.path.join(UPLOAD_DIR, f"upload_{uuid.uuid4().hex}")
    max_bytes = MAX_SIZE_MB * 1024 * 1024
    try:
        # Copy in 1 MiB chunks and stop as soon as the cap is exceeded, rather
        # than reading the whole body into a single bytes object first.
        written = 0
        with open(input_path, "wb") as out:
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                written += len(chunk)
                if written > max_bytes:
                    raise ValueError(f"File too large! Max {MAX_SIZE_MB}MB allowed.")
                out.write(chunk)

        # Background removal is synchronous and CPU-heavy; run it in the
        # thread pool so the event loop keeps serving other requests.
        result_path = await run_in_threadpool(replace_background, input_path, bg_choice)
        return FileResponse(result_path)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=400)
    finally:
        # The raw upload is only needed while processing.
        try:
            os.remove(input_path)
        except FileNotFoundError:
            pass  # never created, or already deleted by check_size
|
|
|
|
|
@app.post("/generate-captions")
async def generate_captions_api(prompt: str = Form(...)):
    """Return Gemini-generated captions for *prompt* as ``{"captions": str}``.

    ``generate_caption`` makes a synchronous model call, so it runs in the
    thread pool instead of blocking the event loop. Failures are reported
    inside the ``captions`` string (``generate_caption`` never raises), not
    as HTTP errors.
    """
    captions = await run_in_threadpool(generate_caption, prompt)
    return {"captions": captions}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(css="""
footer {display:none !important}
.gradio-container {max-width: 100% !important; font-family: 'Segoe UI', sans-serif;}
h1, h2, h3, label {font-weight:600 !important;}
.box {padding: 12px; border-radius: 15px; background: var(--block-background-fill);
box-shadow: 0 2px 8px rgba(0,0,0,0.05); margin-bottom:12px;}
#output-img img {width:100% !important; height:auto !important; border-radius:18px;
box-shadow:0 4px 12px rgba(0,0,0,0.15);}
""") as demo:
    # Per-session theme flag ("light" / "dark").
    theme_state = gr.State("light")

    def toggle_theme(current):
        """Return the opposite of the *current* theme flag."""
        return "dark" if current == "light" else "light"

    def load_bg(choice):
        """Return stock background *choice* from BG_DIR as a PIL image, or None.

        Only the bare file name is used, so the lookup cannot leave BG_DIR.
        The image is copied into memory so its file handle is closed at once.
        """
        if not choice:
            return None
        path = os.path.join(BG_DIR, os.path.basename(choice))
        if not os.path.isfile(path):
            return None
        with Image.open(path) as src:
            return src.copy()

    # Stock backgrounds: visible regular files in BG_DIR, listed once at
    # startup and sorted so the default selection is stable.
    bg_files = sorted(
        name
        for name in os.listdir(BG_DIR)
        if not name.startswith(".") and os.path.isfile(os.path.join(BG_DIR, name))
    )
    default_bg = bg_files[0] if bg_files else None

    with gr.Row():
        with gr.Column():
            gr.Markdown("<h1 style='text-align:center; font-size:2.2em;'>✨ SnapLift – AI Social Media Booster</h1>")
            gr.Markdown("<p style='text-align:center; font-size:1.1em; color:#555;'>Upload or capture your product photo, replace background, and auto-generate <b>marketing captions</b> + <b>hashtags</b>!</p>")
        # Gradio expects an integer scale; 0 keeps this column at its min_width.
        with gr.Column(scale=0, min_width=160):
            theme_btn = gr.Button("🌙 Toggle Theme")

    theme_btn.click(fn=toggle_theme, inputs=theme_state, outputs=theme_state, queue=False)
    # Flipping the state alone changes nothing on screen. This client-side
    # handler toggles the `dark` class on <body>, which Gradio's dark-mode
    # styles are keyed on. NOTE(review): the `js=` argument requires Gradio 4+.
    theme_btn.click(fn=None, js="() => { document.body.classList.toggle('dark'); }")

    with gr.Tab("📸 Image Editor"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group(elem_classes="box"):
                    input_img = gr.Image(
                        type="pil",
                        label="📤 Upload or Capture Main Photo",
                        sources=["upload", "webcam"],
                        interactive=True,
                    )

                with gr.Accordion("🎨 Background Options", open=True):
                    bg_choices = gr.Dropdown(
                        choices=bg_files,
                        value=default_bg,
                        label="Choose Background",
                    )
                    bg_upload = gr.Image(
                        type="pil",
                        label="📤 Upload or Capture Background",
                        sources=["upload", "webcam"],
                        min_width=250,
                    )
                    # Pre-filled with the default background so the preview
                    # matches the dropdown before any change event fires.
                    bg_preview = gr.Image(
                        value=load_bg(default_bg),
                        type="pil",
                        label="Background Preview",
                        interactive=False,
                    )
                    bg_choices.change(fn=load_bg, inputs=bg_choices, outputs=bg_preview)

                with gr.Accordion("🏷️ Branding", open=False):
                    logo_upload = gr.Image(type="pil", label="Upload Logo")
                    # process_image multiplies the logo's alpha by value/100, so
                    # higher values make the logo MORE visible: this is opacity.
                    logo_transparency = gr.Slider(0, 100, value=70, label="Logo Opacity (%)")
                    logo_position = gr.Dropdown(
                        ["Top-Left", "Top-Right", "Bottom-Left", "Bottom-Right", "Center"],
                        value="Bottom-Right",
                        label="Logo Position",
                    )

                with gr.Accordion("✨ Realism Settings", open=False):
                    blend_strength = gr.Slider(0, 1, value=0.5, step=0.1, label="Blending Strength")
                    blur_background = gr.Checkbox(label="Blur Background", value=False)

                with gr.Accordion("💾 Export Options", open=True):
                    export_format = gr.Dropdown(
                        ["PNG", "JPG", "PDF"],
                        value="PNG",
                        label="Export Format",
                    )
                    export_size = gr.Dropdown(
                        [
                            "Original",
                            "Instagram (1080x1080)",
                            "Facebook (1200x628)",
                            "TikTok (1080x1920)",
                            "LinkedIn (1200x1200)",
                            "Twitter (1600x900)",
                        ],
                        value="Original",
                        label="Social Media Size",
                    )

                btn = gr.Button("🚀 Generate & Export", elem_classes="box")

            with gr.Column(scale=1):
                gr.Markdown("### 🖼️ Preview")
                output_img = gr.Image(
                    type="filepath",
                    label="Generated Image",
                    elem_id="output-img",
                    interactive=False,
                )
                download_btn = gr.File(label="⬇️ Download HD Export")

        def process_and_export(input_img, bg_choice, bg_upload, logo_upload, logo_transparency, logo_position, blur_background, blend_strength, export_format, export_size):
            """Render the edited image, then export a (optionally resized) copy.

            Returns:
                ``(preview_path, download_path)``: the full-size PNG from
                ``process_image`` for the preview pane, and the exported file
                in *export_format* for the download widget.

            Raises:
                gr.Error: no main photo was supplied (shown to the user).
            """
            if input_img is None:
                # process_image would return [] here; report it readably
                # instead of failing with an IndexError.
                raise gr.Error("Please upload or capture a main photo first.")

            result = process_image(input_img, bg_choice, bg_upload, logo_upload, logo_transparency, logo_position, blur_background, blend_strength)
            if isinstance(result, list):
                result = result[0]

            size_map = {
                "Instagram (1080x1080)": (1080, 1080),
                "Facebook (1200x628)": (1200, 628),
                "TikTok (1080x1920)": (1080, 1920),
                "LinkedIn (1200x1200)": (1200, 1200),
                "Twitter (1600x900)": (1600, 900),
            }
            target_size = size_map.get(export_size)  # None means keep original size

            # Work on an in-memory copy so the source file handle is closed now.
            with Image.open(result) as src:
                img = src.resize(target_size, Image.LANCZOS) if target_size else src.copy()

            # The random suffix keeps same-second exports from overwriting each other.
            stem = f"export_{datetime.now().strftime('%Y%m%d%H%M%S')}_{uuid.uuid4().hex[:8]}"
            if export_format == "PNG":
                out_path = os.path.join(RESULTS_DIR, f"{stem}.png")
                img.save(out_path, "PNG")  # lossless; a `quality` option has no effect for PNG
            elif export_format == "JPG":
                out_path = os.path.join(RESULTS_DIR, f"{stem}.jpg")
                img.convert("RGB").save(out_path, "JPEG", quality=95)
            elif export_format == "PDF":
                out_path = os.path.join(RESULTS_DIR, f"{stem}.pdf")
                img.convert("RGB").save(out_path, "PDF", resolution=300.0)
            else:
                out_path = result

            return result, out_path

        btn.click(
            fn=process_and_export,
            inputs=[input_img, bg_choices, bg_upload, logo_upload, logo_transparency, logo_position, blur_background, blend_strength, export_format, export_size],
            outputs=[output_img, download_btn],
        )

    with gr.Tab("✍️ Caption Generator"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group(elem_classes="box"):
                    prompt = gr.Textbox(label="📝 Enter product/promotion text", value="Promote my skincare product")
                    btn2 = gr.Button("💡 Suggest Captions + Hashtags")
            with gr.Column(scale=1):
                caption_box = gr.Textbox(label="Suggested Posts (multi-platform)", lines=12)

        btn2.click(fn=generate_caption, inputs=[prompt], outputs=[caption_box])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve the Gradio UI on :7860 from a daemon thread, so it exits together
    # with the process. The FastAPI app on :8000 runs in the main thread.
    threading.Thread(
        target=demo.launch,
        kwargs={"server_name": "0.0.0.0", "server_port": 7860, "show_api": True},
        daemon=True,
    ).start()
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|