Spaces:
Sleeping
Sleeping
| import os | |
| import uuid | |
| import logging | |
| from pathlib import Path | |
| from flask import Flask, render_template, request, redirect, url_for, send_from_directory, current_app, jsonify | |
| from werkzeug.utils import secure_filename | |
| # ML / image libs | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import pandas as pd | |
| # rembg and TF | |
| from rembg import remove | |
| import tensorflow as tf | |
| from tensorflow.keras.preprocessing import image | |
| from tensorflow.keras.applications import VGG16 | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| # ------------------------- | |
| # Config | |
| # ------------------------- | |
# Filesystem layout: every path below is resolved relative to this module.
BASE_DIR = Path(__file__).parent
UPLOAD_DIR = BASE_DIR / "uploads"
RESULT_DIR = BASE_DIR / "results"
CAPTIONS_CSV = BASE_DIR / "aesthetic_instagram_captions.csv"

# Upload constraints: accepted extensions (lower-case, no dot) and size cap.
ALLOWED_EXT = {"png", "jpg", "jpeg", "webp"}
MAX_UPLOAD_SIZE = 12 * 1024 * 1024  # 12 MB

# Create the working directories at import time so request handlers can
# write into them without checking first.
for _directory in (UPLOAD_DIR, RESULT_DIR):
    _directory.mkdir(parents=True, exist_ok=True)

# Flask app; MAX_CONTENT_LENGTH makes Flask reject larger request bodies.
app = Flask(__name__, template_folder='templates')
app.config.update(
    UPLOAD_FOLDER=str(UPLOAD_DIR),
    RESULT_FOLDER=str(RESULT_DIR),
    MAX_CONTENT_LENGTH=MAX_UPLOAD_SIZE,
)

# Logging: INFO-level root configuration plus a module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _enable_gpu_memory_growth():
    """Ask TensorFlow to grow GPU memory on demand for every visible GPU.

    Best effort: a failure for one device is logged as a warning and the
    remaining devices are still processed.

    Returns:
        The list of GPU devices reported by TensorFlow.
    """
    gpus = tf.config.list_physical_devices('GPU')
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except Exception as err:
            logger.warning("Could not set memory growth: %s", err)
    return gpus


# Optional: limit TF memory growth at import time (avoids many OOMs).
physical_devices = _enable_gpu_memory_growth()
| # ------------------------- | |
| # Helpers | |
| # ------------------------- | |
def allowed_file(filename):
    """Return True when *filename* carries an extension listed in ALLOWED_EXT.

    The extension is the text after the last dot, compared case-insensitively.
    A name without any dot is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXT
def unique_secure_filename(original_name):
    """Build a collision-resistant, sanitised file name for an upload.

    The result is a random 32-character hex UUID, an underscore, and the
    name after it has been passed through werkzeug's ``secure_filename``.
    """
    sanitised = secure_filename(original_name)
    return "_".join((uuid.uuid4().hex, sanitised))
| # ------------------------- | |
| # Load captions CSV (defensive) | |
| # ------------------------- | |
def _load_captions(csv_path):
    """Read the captions CSV, returning None when it cannot be used.

    None is returned (and the reason logged) when the file is missing, when
    reading it raises, or when it has no 'Captions' column.
    """
    if not csv_path.exists():
        logger.error("Captions CSV not found at: %s", csv_path)
        return None
    try:
        frame = pd.read_csv(csv_path)
        if 'Captions' in frame.columns:
            return frame
        logger.error("CSV present but missing 'Captions' column. Columns: %s", frame.columns.tolist())
        return None
    except Exception:
        logger.exception("Failed to read captions CSV")
        return None


# Loaded once at import time; None means caption generation is unavailable.
captions_df = _load_captions(CAPTIONS_CSV)
| # ------------------------- | |
| # Load VGG model (defensive) | |
| # ------------------------- | |
# Load the feature extractor once at import time. Any failure is logged and
# leaves ``model`` as None so the app can still start.
try:
    model = VGG16(weights='imagenet', include_top=False, pooling='max')
except Exception as e:
    logger.exception("Failed to load VGG16 model: %s", e)
    model = None
else:
    logger.info("VGG16 loaded successfully")
| # ------------------------- | |
| # Image & ML functions | |
| # ------------------------- | |
def remove_background(image_path):
    """Strip the background from an image and save the result as a PNG.

    The image is converted to RGBA, passed through rembg's ``remove``,
    resized to 512x512 and written next to the input file.

    Args:
        image_path: Path of the image to process.

    Returns:
        The output path as a string: the input path with its last extension
        replaced by ``_no_bg.png``.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised.
    """
    try:
        # The context manager releases the underlying file handle;
        # convert() returns a separate image, so it stays usable afterwards.
        with Image.open(image_path) as source:
            inp = source.convert("RGBA")
        out = remove(inp)  # rembg returns a PIL image for PIL input
        out = out.resize((512, 512))
        output_path = str(Path(image_path).with_suffix('')) + '_no_bg.png'
        out.save(output_path)
        return output_path
    except Exception:
        logger.exception("remove_background failed for %s", image_path)
        raise
def change_background(foreground_path, background_path):
    """Composite a foreground image over a background image.

    Both images are resized to 512x512. Foreground pixels are selected by a
    binary mask: the alpha channel when the foreground has one, otherwise
    its grayscale intensity. In both cases values above 1 count as
    foreground; everything else shows the background.

    Args:
        foreground_path: Path of the foreground image (read unchanged, so an
            alpha channel is kept when present).
        background_path: Path of the background image (read as 3-channel
            colour).

    Returns:
        The composited image as a 512x512 3-channel array in OpenCV's BGR
        channel order.

    Raises:
        ValueError: If OpenCV cannot read either image.
        Exception: Any failure is logged with its traceback and re-raised.
    """
    try:
        fg = cv2.imread(str(foreground_path), cv2.IMREAD_UNCHANGED)
        bg = cv2.imread(str(background_path), cv2.IMREAD_COLOR)
        if fg is None or bg is None:
            raise ValueError("Failed to read foreground or background image with OpenCV")
        fg = cv2.resize(fg, (512, 512))
        bg = cv2.resize(bg, (fg.shape[1], fg.shape[0]))
        if fg.ndim == 2:
            # IMREAD_UNCHANGED keeps a grayscale file as a 2-D array; expand
            # it to BGR so the channel indexing below is valid.
            fg = cv2.cvtColor(fg, cv2.COLOR_GRAY2BGR)
        # if FG has alpha channel, use it
        if fg.shape[2] == 4:
            mask_source = fg[:, :, 3]
        else:
            mask_source = cv2.cvtColor(fg, cv2.COLOR_BGR2GRAY)
        mask = cv2.threshold(mask_source, 1, 255, cv2.THRESH_BINARY)[1]
        mask_inv = cv2.bitwise_not(mask)
        fg_bgr = fg[:, :, :3]
        bg_part = cv2.bitwise_and(bg, bg, mask=mask_inv)
        fg_part = cv2.bitwise_and(fg_bgr, fg_bgr, mask=mask)
        # The two parts cover disjoint pixels, so adding them merges them.
        return cv2.add(bg_part, fg_part)
    except Exception:
        logger.exception("change_background failed for %s + %s", foreground_path, background_path)
        raise
def extract_features(img_path):
    """Run the module-level VGG16 model on one image and return its features.

    The image is loaded at 224x224, turned into a single-item batch and
    passed through the VGG16 preprocessing before prediction.

    Raises:
        RuntimeError: If the model failed to load at import time.
        Exception: Any other failure is logged with its traceback and
            re-raised.
    """
    if model is None:
        raise RuntimeError("Model is not loaded")
    try:
        pil_img = image.load_img(img_path, target_size=(224, 224))
        batch = np.expand_dims(image.img_to_array(pil_img), axis=0)
        batch = tf.keras.applications.vgg16.preprocess_input(batch)
        return model.predict(batch)  # shape (1, 512)
    except Exception:
        logger.exception("extract_features failed for %s", img_path)
        raise
def generate_caption_from_image(img_path):
    """Pick a caption from the captions dataset for the given image.

    NOTE: this is a placeholder. Caption "features" are random vectors, so
    the chosen caption is effectively random. Replace them with real caption
    embeddings if available.

    Args:
        img_path: Path of the image to caption.

    Returns:
        The caption whose (random) feature vector has the highest cosine
        similarity to the image features, or "" when the dataset has no
        usable captions.

    Raises:
        RuntimeError: If the captions dataset (or the model, via
            ``extract_features``) is not loaded.
    """
    if captions_df is None:
        raise RuntimeError("Captions dataset not loaded")
    img_feats = extract_features(img_path)  # (1,512)
    # Drop empty cells so a NaN is never returned as the caption.
    captions = captions_df['Captions'].dropna().to_numpy()
    if len(captions) == 0:
        return ""
    try:
        # One batched draw and one similarity call replace the per-caption
        # loop; this builds a (n_captions, n_features) matrix in memory.
        caption_feats = np.random.rand(len(captions), img_feats.shape[1])
        scores = cosine_similarity(img_feats, caption_feats)[0]
    except (ValueError, IndexError):
        # Best effort, as before: if similarity cannot be computed (e.g.
        # non-finite or mis-shaped features), fall back to the first caption.
        logger.warning("Caption similarity failed for %s; using first caption", img_path, exc_info=True)
        return captions[0]
    # argmax returns the first maximum, matching a strict ">" scan.
    return captions[int(np.argmax(scores))]
| # ------------------------- | |
| # Routes | |
| # ------------------------- | |
def _process_background_swap(fore, back):
    """Validate a foreground/background upload pair and composite them.

    Returns the rendered response: ``index.html`` with an error message when
    validation fails, otherwise ``output.html`` with the result image path.
    """
    if not fore.filename or not back.filename:
        return render_template('index.html', error="Please select both foreground and background images.")
    if not allowed_file(fore.filename) or not allowed_file(back.filename):
        return render_template('index.html', error="Unsupported file type.")
    f_path = UPLOAD_DIR / unique_secure_filename(fore.filename)
    b_path = UPLOAD_DIR / unique_secure_filename(back.filename)
    fore.save(str(f_path))
    back.save(str(b_path))
    # remove bg and change
    fg_no_bg = remove_background(str(f_path))
    result_img = change_background(fg_no_bg, str(b_path))
    # A per-request name keeps concurrent requests from overwriting each
    # other's output.
    result_path = RESULT_DIR / f"{uuid.uuid4().hex}_result.png"
    # imwrite reports failure through its return value.
    if not cv2.imwrite(str(result_path), result_img):
        raise OSError(f"Could not write result image to {result_path}")
    return render_template('output.html', image_path=str(result_path))


def _process_caption(imgfile):
    """Validate a single uploaded image and render a caption for it."""
    if not imgfile.filename:
        return render_template('index.html', error="Please select an image.")
    if not allowed_file(imgfile.filename):
        return render_template('index.html', error="Unsupported file type.")
    saved_path = UPLOAD_DIR / unique_secure_filename(imgfile.filename)
    imgfile.save(str(saved_path))
    # generate caption
    caption = generate_caption_from_image(str(saved_path))
    return render_template('output.html', caption=caption)


# NOTE(review): no route decorator was present on this view, so it was never
# reachable. '/' with GET and POST is inferred from the request.method check
# below; confirm against the templates' form action.
@app.route('/', methods=['GET', 'POST'])
def index():
    """Main page: show the upload form and handle both upload modes.

    A POST carrying both 'foreground' and 'background' files runs the
    background swap; a POST carrying an 'image' file runs caption
    generation; anything else renders the form.

    HTTP errors raised by the framework (for example the 413 for a body over
    MAX_CONTENT_LENGTH) propagate with their own status. Any other exception
    is logged and answered with ``error.html`` and status 500.
    """
    # Local import: werkzeug is already a dependency of this module.
    from werkzeug.exceptions import HTTPException

    try:
        if request.method == 'POST':
            files = request.files
            # Background change mode (both files present)
            if 'foreground' in files and 'background' in files:
                return _process_background_swap(files['foreground'], files['background'])
            # Caption generation mode (single image field named 'image')
            if 'image' in files:
                return _process_caption(files['image'])
        return render_template('index.html')
    except HTTPException:
        # Keep the framework's own status code instead of masking it as 500.
        raise
    except Exception:
        logger.exception("Unhandled error in index route")
        # render a friendly page with an error message
        return render_template('error.html', message="Internal server error. Check server logs for details."), 500
# NOTE(review): no route decorator was present on this view, so it was never
# reachable. The '/uploads/<filename>' rule is inferred; confirm against the
# templates.
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve a previously uploaded file for download/view.

    The requested name is sanitised with ``secure_filename`` before lookup.

    Returns:
        The file response, or ("File not found", 404) when the sanitised
        name is empty or does not name a regular file in the upload folder.
    """
    safe = secure_filename(filename)
    # is_file() also rejects directories, which a bare exists() would accept.
    if not safe or not (UPLOAD_DIR / safe).is_file():
        return "File not found", 404
    return send_from_directory(app.config['UPLOAD_FOLDER'], safe)
# Static template route.
# NOTE(review): no route decorator was present on this view, so it was never
# reachable. The '/home' URL is an assumption; confirm the intended path.
@app.route('/home')
def home():
    """Render the upload form page (``index.html``)."""
    return render_template('index.html')
# Script entry point: listen on all interfaces, port 7860, debugger off.
if __name__ == "__main__":
    run_options = {"host": "0.0.0.0", "port": 7860, "debug": False}
    app.run(**run_options)