# captiongenrater / app.py
# (Hugging Face Space by deedrop1140 — commit 1e7c60e "Update app.py")
import os
import uuid
import logging
from pathlib import Path
from flask import Flask, render_template, request, redirect, url_for, send_from_directory, current_app, jsonify
from werkzeug.utils import secure_filename
# ML / image libs
import warnings
warnings.filterwarnings('ignore')
from PIL import Image
import numpy as np
import cv2
import pandas as pd
# rembg and TF
from rembg import remove
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import VGG16
from sklearn.metrics.pairwise import cosine_similarity
# -------------------------
# Configuration
# -------------------------
BASE_DIR = Path(__file__).parent
UPLOAD_DIR = BASE_DIR / "uploads"
RESULT_DIR = BASE_DIR / "results"
CAPTIONS_CSV = BASE_DIR / "aesthetic_instagram_captions.csv"
ALLOWED_EXT = {"png", "jpg", "jpeg", "webp"}
MAX_UPLOAD_SIZE = 12 * 1024 * 1024  # 12 MB request cap, enforced by Flask

# Make sure the working directories exist before the first request arrives.
for _dir in (UPLOAD_DIR, RESULT_DIR):
    _dir.mkdir(parents=True, exist_ok=True)

# Flask application and its upload/result configuration.
app = Flask(__name__, template_folder='templates')
app.config['UPLOAD_FOLDER'] = str(UPLOAD_DIR)
app.config['RESULT_FOLDER'] = str(RESULT_DIR)
app.config['MAX_CONTENT_LENGTH'] = MAX_UPLOAD_SIZE

# Module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Best effort: enable GPU memory growth so TF doesn't grab all VRAM up front
# (avoids many OOMs); failure here is non-fatal.
physical_devices = tf.config.list_physical_devices('GPU')
for dev in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(dev, True)
    except Exception as e:
        logger.warning("Could not set memory growth: %s", e)
# -------------------------
# Helpers
# -------------------------
def allowed_file(filename, allowed_ext=None):
    """Return True if *filename* has an extension in the allowed set.

    Args:
        filename: Name as supplied by the client (may be empty).
        allowed_ext: Optional iterable of lowercase extensions (without the
            dot). Defaults to the module-level ALLOWED_EXT set, preserving
            the original behavior for existing callers.
    """
    exts = ALLOWED_EXT if allowed_ext is None else set(allowed_ext)
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in exts
def unique_secure_filename(original_name):
    """Sanitize *original_name* and prefix it with a random hex id.

    Guarantees distinct stored names even when two clients upload files
    with identical names.
    """
    sanitized = secure_filename(original_name)
    return "{}_{}".format(uuid.uuid4().hex, sanitized)
# -------------------------
# Load captions CSV (defensive)
# -------------------------
# captions_df stays None on any failure; route code checks for that.
captions_df = None
if not CAPTIONS_CSV.exists():
    logger.error("Captions CSV not found at: %s", CAPTIONS_CSV)
else:
    try:
        _df = pd.read_csv(CAPTIONS_CSV)
    except Exception:
        logger.exception("Failed to read captions CSV")
    else:
        if 'Captions' in _df.columns:
            captions_df = _df
        else:
            logger.error("CSV present but missing 'Captions' column. Columns: %s", _df.columns.tolist())
# -------------------------
# Load VGG model (defensive)
# -------------------------
# model stays None on failure; extract_features raises RuntimeError then.
model = None
try:
    model = VGG16(weights='imagenet', include_top=False, pooling='max')
except Exception as e:
    logger.exception("Failed to load VGG16 model: %s", e)
else:
    logger.info("VGG16 loaded successfully")
# -------------------------
# Image & ML functions
# -------------------------
def remove_background(image_path):
    """Strip the background from the image at *image_path* using rembg.

    Saves a 512x512 RGBA PNG next to the input (name suffixed with
    '_no_bg.png') and returns that path. Logs and re-raises on failure.
    """
    try:
        source = Image.open(image_path).convert("RGBA")
        cutout = remove(source).resize((512, 512))  # rembg returns a PIL image
        output_path = str(Path(image_path).with_suffix('')) + '_no_bg.png'
        cutout.save(output_path)
        return output_path
    except Exception:
        logger.exception("remove_background failed for %s", image_path)
        raise
def change_background(foreground_path, background_path):
    """Composite a foreground image over a new background.

    Args:
        foreground_path: Image whose background was removed (ideally the
            RGBA PNG produced by remove_background); read with
            IMREAD_UNCHANGED so an alpha channel, when present, drives
            the mask.
        background_path: Replacement background image.

    Returns:
        A 512x512 BGR numpy array with the foreground pasted onto the
        background.

    Raises:
        ValueError: if either image cannot be decoded by OpenCV.
    """
    try:
        fg = cv2.imread(str(foreground_path), cv2.IMREAD_UNCHANGED)
        bg = cv2.imread(str(background_path), cv2.IMREAD_COLOR)
        if fg is None or bg is None:
            raise ValueError("Failed to read foreground or background image with OpenCV")
        # Fix: IMREAD_UNCHANGED yields a 2-D array for grayscale input, so
        # the original `fg.shape[2]` access raised IndexError; normalize to
        # 3 channels first.
        if fg.ndim == 2:
            fg = cv2.cvtColor(fg, cv2.COLOR_GRAY2BGR)
        fg = cv2.resize(fg, (512, 512))
        bg = cv2.resize(bg, (fg.shape[1], fg.shape[0]))
        if fg.shape[2] == 4:
            # Alpha channel present: use it as the foreground mask.
            alpha = fg[:, :, 3]
            mask = cv2.threshold(alpha, 1, 255, cv2.THRESH_BINARY)[1]
        else:
            # No alpha: treat (near-)black pixels as background.
            gray = cv2.cvtColor(fg, cv2.COLOR_BGR2GRAY)
            mask = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)[1]
        mask_inv = cv2.bitwise_not(mask)
        bg_part = cv2.bitwise_and(bg, bg, mask=mask_inv)
        fg_part = cv2.bitwise_and(fg[:, :, :3], fg[:, :, :3], mask=mask)
        return cv2.add(bg_part, fg_part)
    except Exception:
        logger.exception("change_background failed for %s + %s", foreground_path, background_path)
        raise
def extract_features(img_path):
    """Run the global VGG16 feature extractor on one image.

    Returns the model's prediction for the 224x224-resized image — a
    (1, 512) array with this VGG16 configuration (include_top=False,
    pooling='max').

    Raises:
        RuntimeError: when the model failed to load at startup.
    """
    if model is None:
        raise RuntimeError("Model is not loaded")
    try:
        pil_img = image.load_img(img_path, target_size=(224, 224))
        batch = np.expand_dims(image.img_to_array(pil_img), axis=0)
        batch = tf.keras.applications.vgg16.preprocess_input(batch)
        return model.predict(batch)
    except Exception:
        logger.exception("extract_features failed for %s", img_path)
        raise
def generate_caption_from_image(img_path):
    """Pick a caption for the image at *img_path*.

    NOTE: still a placeholder — the captions have no real embeddings, so
    "similarity" is computed against random vectors and the returned
    caption is effectively a random pick. Replace the random matrix with
    genuine caption embeddings to make this meaningful.

    Improvement over the original: one vectorized cosine_similarity call
    instead of a Python loop with one call per caption, plus explicit
    handling of an empty caption list.

    Returns:
        The best-scoring caption string, or "" when no captions exist.

    Raises:
        RuntimeError: when the captions dataset was not loaded at startup.
    """
    if captions_df is None:
        raise RuntimeError("Captions dataset not loaded")
    captions = captions_df['Captions'].values
    if len(captions) == 0:
        return ""
    img_feats = extract_features(img_path)  # (1, 512)
    # Placeholder "embeddings": one random row per caption.
    caption_feats = np.random.rand(len(captions), img_feats.shape[1])
    sims = cosine_similarity(img_feats, caption_feats)[0]
    return captions[int(np.argmax(sims))]
# -------------------------
# Routes
# -------------------------
@app.route('/', methods=['GET', 'POST'])
def index():
    """Main page: background replacement or caption generation.

    POST with 'foreground' + 'background' files: remove the foreground's
    background and composite it over the new background.
    POST with a single 'image' file: generate a caption for it.
    GET (or a POST with neither field set): render the upload form.
    Any unhandled failure renders error.html with HTTP 500.
    """
    try:
        if request.method == 'POST':
            # Background change mode (both files present)
            if 'foreground' in request.files and 'background' in request.files:
                fore = request.files['foreground']
                back = request.files['background']
                if fore.filename == '' or back.filename == '':
                    return render_template('index.html', error="Please select both foreground and background images.")
                if not allowed_file(fore.filename) or not allowed_file(back.filename):
                    return render_template('index.html', error="Unsupported file type.")
                # Store both uploads under collision-proof, sanitized names.
                f_name = unique_secure_filename(fore.filename)
                b_name = unique_secure_filename(back.filename)
                f_path = UPLOAD_DIR / f_name
                b_path = UPLOAD_DIR / b_name
                fore.save(str(f_path))
                back.save(str(b_path))
                # remove bg and change
                fg_no_bg = remove_background(str(f_path))
                result_img = change_background(fg_no_bg, str(b_path))
                # NOTE(review): fixed output name — concurrent requests will
                # overwrite each other's result; consider a uuid filename.
                result_path = RESULT_DIR / "result.png"
                cv2.imwrite(str(result_path), result_img)
                # NOTE(review): this passes a server filesystem path to the
                # template; no route in this file serves RESULT_DIR, so the
                # browser likely cannot load it — confirm against output.html.
                return render_template('output.html', image_path=str(result_path))
            # Caption generation mode (single image field named 'image')
            elif 'image' in request.files:
                imgfile = request.files['image']
                if imgfile.filename == '':
                    return render_template('index.html', error="Please select an image.")
                if not allowed_file(imgfile.filename):
                    return render_template('index.html', error="Unsupported file type.")
                saved_name = unique_secure_filename(imgfile.filename)
                saved_path = UPLOAD_DIR / saved_name
                imgfile.save(str(saved_path))
                # generate caption
                caption = generate_caption_from_image(str(saved_path))
                return render_template('output.html', caption=caption)
        return render_template('index.html')
    except Exception:
        logger.exception("Unhandled error in index route")
        # render a friendly page with an error message
        return render_template('error.html', message="Internal server error. Check server logs for details."), 500
@app.route('/uploads/<path:filename>')
def uploaded_file(filename):
    """Serve a previously uploaded file for viewing/downloading.

    The requested name is sanitized again before touching the filesystem,
    blocking path traversal; unknown names get a plain 404.
    """
    safe = secure_filename(filename)
    if not (UPLOAD_DIR / safe).exists():
        return "File not found", 404
    return send_from_directory(app.config['UPLOAD_FOLDER'], safe)
# static template routes (kept as before)
@app.route('/index.html')
def home():
    """Alias route so /index.html also renders the landing page."""
    return render_template('index.html')
# run
if __name__ == "__main__":
    # Port 7860 is the Hugging Face Spaces convention; bind all interfaces
    # so the container's mapped port is reachable.
    app.run(host="0.0.0.0", port=7860, debug=False)