Spaces:
Sleeping
Sleeping
| import os | |
| import pickle | |
| from flask import Flask, render_template, request, redirect, url_for, flash, send_file | |
| from flask_bcrypt import Bcrypt | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import onnxruntime | |
| from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering | |
| from werkzeug.utils import secure_filename | |
| import pandas as pd | |
| from duckduckgo_search import DDGS | |
| import os | |
| import urllib.request | |
| import gdown | |
# Initialize Flask app and Bcrypt for password hashing
app = Flask(__name__)
# The session-signing secret should not live in source control: read it from
# the SECRET_KEY environment variable when it is set. The previous placeholder
# stays as the fallback so existing deployments keep starting unchanged.
app.secret_key = os.environ.get('SECRET_KEY', 'your_secret_key')
bcrypt = Bcrypt(app)
# Directory that receives the downloaded model files.
models_folder = "models"

# Remote file identifiers handed to download_model() for each model file.
modelx2_file_id = "1Hvt3_t8S2W5CNYUCFgd2L_KitedAJEmH"
trained_model_file_id = "1VCcCkj6jXBwiJAcHdAmHg_o6u32u4V7i"
vqa_model_file_id = "1YlUXkLx2qQMFAcT0xZ2zfXU5xx2eRFEV"

# Where uploads and upscaled results are written.
app.config.update(
    UPLOAD_FOLDER='static/uploads',
    UPSCALED_FOLDER='static/upscaled',
)

# Create every working directory up front (models first, then uploads,
# then upscaled output) so later writes never hit a missing folder.
for _folder in (
    models_folder,
    app.config['UPLOAD_FOLDER'],
    app.config['UPSCALED_FOLDER'],
):
    os.makedirs(_folder, exist_ok=True)

# File extensions accepted by allowed_file().
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
# Load the BLIP processors and models once at import time so request
# handlers can reuse them instead of reloading on every call.
_CAPTION_CHECKPOINT = "Salesforce/blip-image-captioning-base"
_VQA_CHECKPOINT = "Salesforce/blip-vqa-capfilt-large"

caption_processor = BlipProcessor.from_pretrained(_CAPTION_CHECKPOINT)
caption_model = BlipForConditionalGeneration.from_pretrained(_CAPTION_CHECKPOINT)
vqa_processor = BlipProcessor.from_pretrained(_VQA_CHECKPOINT)
vqa_model = BlipForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)
def download_model(file_id, model_path):
    """Download a model file from Google Drive unless it is already on disk.

    Args:
        file_id: Google Drive file id inserted into the download URL.
        model_path: Local path the file is written to.

    Nothing is returned; progress and the outcome are printed. A download
    that leaves no file behind is reported as a failure rather than as a
    success.
    """
    if os.path.exists(model_path):
        print(f"{model_path} already exists.")
        return

    print(f"Downloading {model_path}...")
    # Make sure the destination directory exists before writing into it.
    target_dir = os.path.dirname(model_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    url = f"https://drive.google.com/uc?export=download&id={file_id}"
    downloaded = gdown.download(url, model_path, quiet=False)

    # NOTE(review): some gdown releases appear to signal a failed download by
    # returning None instead of raising -- confirm for the pinned version.
    # Checking the file itself covers both cases.
    if downloaded is None or not os.path.exists(model_path):
        print(f"Failed to download {model_path}.")
        return
    print(f"{model_path} downloaded successfully.")
# Local destinations of the three model files inside models_folder.
# NOTE(review): the two .pkl files are downloaded here but nothing in the
# visible part of this file loads them -- confirm whether they are still needed.
model_path = os.path.join(models_folder, "modelx2.ort")
caption_model_path = os.path.join(models_folder, "trained_model.pkl")
vqa_model_path = os.path.join(models_folder, "vqa_model.pkl")

# Fetch whichever files are not on disk yet, in the same order as before.
for _file_id, _destination in (
    (modelx2_file_id, model_path),
    (trained_model_file_id, caption_model_path),
    (vqa_model_file_id, vqa_model_path),
):
    download_model(_file_id, _destination)
| # Helper functions | |
def allowed_file(filename):
    """Return True when ``filename`` has an extension listed in ALLOWED_EXTENSIONS.

    The extension is whatever follows the last dot, compared case-insensitively.
    A name without any dot is rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS
def convert_pil_to_cv2(image):
    """Convert an image to a NumPy array with the colour channels in BGR order.

    Args:
        image: Anything ``np.array`` accepts, typically a PIL image.

    Returns:
        A new array. For input with three or more channels the first three
        are reversed (RGB -> BGR) and any further channel, such as alpha,
        keeps its position, so RGBA becomes BGRA. Two-dimensional input and
        input with fewer than three channels is returned without reordering.
    """
    pixels = np.array(image)

    # Nothing to reorder for single-plane images. Reversing the whole channel
    # axis here used to raise IndexError for 2-D input.
    if pixels.ndim != 3 or pixels.shape[2] < 3:
        return pixels

    # Reverse only the colour channels. Reversing the entire axis would move
    # the alpha channel of an RGBA image to index 0 (ABGR).
    converted = pixels.copy()
    converted[:, :, :3] = pixels[:, :, 2::-1]
    return converted
def pre_process(img: np.array) -> np.array:
    """Turn an (H, W, C) image into a float32 batch of shape (1, 3, H, W).

    Only the first three channels are kept; they are moved to the front and
    a leading batch axis of length one is added.
    """
    # Keep three channels and go from H, W, C to C, H, W.
    channels_first = img[:, :, :3].transpose(2, 0, 1)
    # Prepend the batch axis: C, H, W -> 1, C, H, W.
    batched = channels_first[np.newaxis, ...]
    return batched.astype(np.float32)
def post_process(img: np.array) -> np.array:
    """Turn a (1, C, H, W) or (C, H, W) array into an (H, W, C) uint8 image.

    The channel axis is moved last and reversed. Values are clipped to
    0..255 before the cast so out-of-range values saturate instead of
    wrapping around.

    Raises:
        ValueError: if a four-dimensional input has a batch size other than one.
    """
    # Drop only the batch axis. A bare squeeze would also remove a height or
    # width of 1 and break the transpose below.
    if img.ndim == 4:
        img = np.squeeze(img, axis=0)
    # C, H, W -> H, W, C, then reverse the channel order.
    img = np.transpose(img, (1, 2, 0))[:, :, ::-1]
    return np.clip(img, 0, 255).astype(np.uint8)
def inference(model_path: str, img_array: np.array) -> np.array:
    """Run the ONNX model stored at ``model_path`` on ``img_array``.

    A new single-threaded inference session is created on every call. The
    array is fed to the model's first declared input and the first output
    is returned.
    """
    session_options = onnxruntime.SessionOptions()
    session_options.intra_op_num_threads = 1
    session_options.inter_op_num_threads = 1

    session = onnxruntime.InferenceSession(model_path, session_options)
    first_input = session.get_inputs()[0]
    outputs = session.run(None, {first_input.name: img_array})
    return outputs[0]
def upscale(image_path: str, model="modelx2"):
    """Upscale the image stored at ``image_path`` with the ONNX model.

    Args:
        image_path: Path of the image file to read.
        model: Model name. NOTE(review): this argument is not consulted; the
            module-level ``model_path`` is always used, as before -- confirm
            whether per-model paths are intended.

    Returns:
        The upscaled image as a uint8 array: three channels for images
        without alpha, four for images with alpha. The colour channels get
        the same post_process + BGR/RGB swap as the three-channel path
        always did, so both kinds of image come back in the same channel
        order for the caller's cv2.imwrite.
    """
    # Normalise the input with PIL first. Grayscale and palette images become
    # plain RGB, and the alpha plane is separated explicitly rather than by
    # channel index, so it cannot be mistaken for a colour channel.
    with Image.open(image_path) as pil_image:
        if "A" in pil_image.getbands():
            rgba = pil_image.convert("RGBA")
            alpha = np.array(rgba.getchannel("A"))  # GRAY
            colour = rgba.convert("RGB")
        else:
            alpha = None
            colour = pil_image.convert("RGB")

    img = convert_pil_to_cv2(colour)  # BGR
    image_output = post_process(inference(model_path, pre_process(img)))
    # post_process reverses the channel axis; swap it back.
    image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2RGB)
    if alpha is None:
        return image_output

    # The model takes three channels, so the alpha plane is upscaled as a
    # grey three-channel image and collapsed back to one plane afterwards.
    alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)
    alpha_output = post_process(inference(model_path, pre_process(alpha)))
    alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)

    image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)
    image_output[:, :, 3] = alpha_output
    return image_output
| # Main route | |
def index():
    """Render the landing page with the list of selectable model names."""
    # NOTE(review): no route decorator is visible on this view function in
    # this copy of the file -- confirm how it is registered with the app.
    model_choices = ["modelx2", "modelx4"]
    return render_template('index.html', models=model_choices)
def _find_similar_images(caption):
    """Return image URLs from a DuckDuckGo image search for ``caption``.

    The search is best-effort: any failure is logged and an empty list is
    returned, so a search problem cannot discard a finished upscale.
    """
    try:
        results = DDGS().images(
            keywords=caption,
            region="wt-wt",
            safesearch="off",
            size=None,
            color="Monochrome",
            type_image=None,
            layout=None,
            license_image=None,
            max_results=100,
        )
        return [result['image'] for result in results]
    except Exception:
        app.logger.exception("Similar image search failed")
        return []


def upload_file():
    """Handle an image upload: save it, upscale it, caption it, find similar images.

    Reads the ``file`` part of the submitted form. On success the index
    template is rendered with the paths of the input and upscaled images and
    the list of similar image URLs. On any problem a message is flashed and
    the client is redirected to the index page.
    """
    if 'file' not in request.files:
        flash('Please upload an image.')
        return redirect(url_for('index'))

    file = request.files['file']
    if not file or not allowed_file(file.filename):
        flash('Invalid file format. Please upload a PNG, JPG, or JPEG file.')
        return redirect(url_for('index'))

    filename = secure_filename(file.filename)
    # Sanitising can strip characters and leave a name without a usable
    # extension (or nothing at all), so validate the sanitised name again.
    if not allowed_file(filename):
        flash('Invalid file format. Please upload a PNG, JPG, or JPEG file.')
        return redirect(url_for('index'))

    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)

    try:
        upscaled_img = upscale(filepath)
        upscaled_filename = f"upscaled_{filename}"
        upscaled_path = os.path.join(app.config['UPSCALED_FOLDER'], upscaled_filename)
        # Check the write result explicitly so a failed write is reported
        # here and not as a confusing error when the file is reopened.
        if not cv2.imwrite(upscaled_path, upscaled_img):
            raise OSError(f"could not write {upscaled_path}")

        # Re-read the saved result for captioning; the context manager
        # releases the file handle once the pixels are converted.
        with Image.open(upscaled_path) as saved_image:
            image = saved_image.convert("RGB")
        caption = generate_caption(image)
        similar_images = _find_similar_images(caption)
        return render_template(
            'index.html',
            input_image_url=filepath,
            image_url=upscaled_path,
            similar_images=similar_images,
            show_buttons=True,
        )
    except Exception as e:
        flash(f"Upscaling failed: {e}")
        return redirect(url_for('index'))
def process_image():
    """Answer a question about, or caption, a previously upscaled image.

    Reads ``image_url`` from the submitted form and keeps only its base name,
    which is looked up inside the upscaled-images folder. If the form contains
    ``vqa`` and a ``question``, the answer is rendered; if it contains
    ``caption``, a caption is rendered. In every other case, including a
    missing file, a message may be flashed and the client is redirected to
    the index page.
    """
    # A missing form field yields None; fall back to '' so basename() does
    # not raise and the request is handled as "file not found".
    image_name = os.path.basename(request.form.get('image_url') or '')
    filepath = os.path.join(app.config['UPSCALED_FOLDER'], image_name)

    # Check for the file before opening it. Opening first would raise on a
    # missing file and the message below could never be shown.
    if not image_name or not os.path.isfile(filepath):
        flash("File not found. Please re-upload the image.")
        return redirect(url_for('index'))

    with Image.open(filepath) as stored_image:
        image = stored_image.convert("RGB")

    if 'vqa' in request.form:
        question = request.form.get('question')
        if question:
            answer = answer_question(image, question)
            return render_template('index.html', image_url=filepath, answer=answer, show_buttons=True, question=question)
        flash("Please enter a question.")
    elif 'caption' in request.form:
        caption = generate_caption(image)
        return render_template('index.html', image_url=filepath, caption=caption, show_buttons=True)

    return redirect(url_for('index'))
def generate_caption(image):
    """Return a text caption for ``image`` using the preloaded captioning model."""
    # Let the processor turn the image into model-ready tensors.
    model_inputs = caption_processor(images=image, return_tensors="pt")
    # The model produces token ids; the first sequence is decoded to text
    # with special tokens removed.
    token_ids = caption_model.generate(**model_inputs)
    return caption_processor.decode(token_ids[0], skip_special_tokens=True)
def answer_question(image, question):
    """Return the preloaded VQA model's answer to ``question`` about ``image``."""
    # Encode the image together with the question text.
    model_inputs = vqa_processor(images=image, text=question, return_tensors="pt")
    # The model produces token ids; the first sequence is decoded to text
    # with special tokens removed.
    token_ids = vqa_model.generate(**model_inputs)
    return vqa_processor.decode(token_ids[0], skip_special_tokens=True)
def serve_uploaded_file(filename):
    """Send the file called ``filename`` from the upload folder.

    The name is resolved with ``send_from_directory``, which refuses paths
    that would leave the upload folder. Joining a caller-supplied name onto
    the folder and passing it to ``send_file`` gives no such protection.
    """
    # Imported here because the module-level flask import does not list it.
    from flask import send_from_directory

    return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
def serve_upscaled_file(filename):
    """Send the file called ``filename`` from the upscaled-images folder.

    The name is resolved with ``send_from_directory``, which refuses paths
    that would leave the folder. Joining a caller-supplied name onto the
    folder and passing it to ``send_file`` gives no such protection.
    """
    # Imported here because the module-level flask import does not list it.
    from flask import send_from_directory

    return send_from_directory(app.config['UPSCALED_FOLDER'], filename)
# Run app
if __name__ == '__main__':
    # Debug mode turns on the interactive debugger, which must not be exposed
    # on a reachable server. Make it opt-in through FLASK_DEBUG instead of
    # always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in {'1', 'true', 'yes', 'on'}
    app.run(debug=debug_enabled)