| import datetime |
| import os |
| from pathlib import Path |
| import sys |
| from flask import Flask, jsonify, request, send_file, abort |
| from flask_uploads import UploadSet, configure_uploads, IMAGES |
| from werkzeug.exceptions import default_exceptions |
| from werkzeug.exceptions import HTTPException, NotFound |
| import json |
| import torch |
| import time |
| import threading |
| import traceback |
| from PIL import Image |
| import numpy as np |
|
|
# Make the parent package and the sibling ``wise`` checkout importable,
# both resolved relative to the directory containing this script.
PACKAGE_PARENT = '..'
WISE_DIR = '../wise/'
SCRIPT_DIR = os.path.dirname(
    os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__)))
)
sys.path.extend(
    os.path.normpath(os.path.join(SCRIPT_DIR, relative))
    for relative in (PACKAGE_PARENT, WISE_DIR)
)
|
|
|
|
|
|
| from parameter_optimization.parametric_styletransfer import single_optimize |
| from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG |
| from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to |
| from helpers import torch_to_np, np_to_torch |
| from effects import get_default_settings, MinimalPipelineEffect |
|
|
class JSONExceptionHandler(object):
    """Register Flask error handlers that render errors as ``{"message": ...}`` JSON.

    Handlers are installed for ``HTTPException`` and for every status code
    listed in werkzeug's ``default_exceptions``.
    """

    def __init__(self, app=None):
        if app is not None:
            self.init_app(app)

    def std_handler(self, error):
        """Turn *error* into a JSON response with its message and status code.

        HTTP exceptions keep their own status code and ``description``;
        anything else becomes a 500 carrying ``str(error)``.
        """
        # werkzeug's HTTPException exposes ``description``, not ``message``
        # (and Python 3 exceptions have no ``.message`` either), so reading
        # ``error.message`` made the error handler itself raise.
        if isinstance(error, HTTPException):
            message = error.description
            # A bare HTTPException has ``code`` None; never emit a None status.
            code = error.code or 500
        else:
            message = str(error)
            code = 500
        response = jsonify(message=message)
        response.status_code = code
        return response

    def init_app(self, app):
        """Bind to *app* and register :meth:`std_handler` for all HTTP errors."""
        self.app = app
        self.register(HTTPException)
        for code in default_exceptions:
            self.register(code)

    def register(self, exception_or_code, handler=None):
        """Register *handler* (default: :meth:`std_handler`) for an exception class or status code."""
        self.app.errorhandler(exception_or_code)(handler or self.std_handler)
|
|
|
|
|
|
# Flask app whose HTTP errors are rendered as JSON by JSONExceptionHandler.
app = Flask(__name__)
handler = JSONExceptionHandler(app)

# Uploaded images are saved into this folder; StyleTask also writes its
# outputs here.
image_folder = 'img_received'
photos = UploadSet('photos', IMAGES)
app.config.update(UPLOADED_PHOTOS_DEST=image_folder)
configure_uploads(app, photos)
|
|
class Args(object):
    """Simple attribute bag: every key of a mapping becomes an attribute."""

    def __init__(self, initial_data):
        self.set_attributes(initial_data)

    def set_attributes(self, val_dict):
        """Copy each ``key -> value`` of *val_dict* onto this object, overwriting existing ones."""
        for name in val_dict:
            value = val_dict[name]
            setattr(self, name, value)
|
|
# Baseline optimizer arguments. StyleTask.start re-applies these and then
# fills in the per-task image paths.
default_args = dict(
    output_image="output.jpg",
    content_image="",
    style_image="",
    output_vp="",
    iters=500,
)

# Number of tasks accepted since startup (bumped by StylerQueue.queue_task).
total_task_count = 0
|
|
class NeuralOptimizer():
    """Run one STROTSS + parametric-pipeline optimization for the paths in ``args``.

    ``cur_iteration`` is updated from the optimizer's iteration callback so
    other threads can poll progress while :meth:`optimize` runs.
    """

    def __init__(self, args) -> None:
        self.cur_iteration = 0
        self.args = args

    @staticmethod
    def _float_to_uint8(array):
        """Clip a float array to [0, 255] and cast it to ``uint8``.

        Casting out-of-range floats straight to uint8 is not well-defined
        (values can wrap), so overshoot is clamped first.
        """
        return np.clip(array, 0, 255).astype(np.uint8)

    @staticmethod
    def _make_reference(content, style, base_dir, resize_to=720):
        """Run STROTSS on 1024px copies, resize the result to *resize_to*, and save it.

        Returns the path of the saved ``reference.jpg`` inside *base_dir*.
        """
        reference = strotss(pil_resize_long_edge_to(content, 1024),
                            pil_resize_long_edge_to(style, 1024), content_weight=16.0,
                            device=torch.device("cuda"), space="uniform")
        ref_save_path = os.path.join(base_dir, "reference.jpg")
        reference = pil_resize_long_edge_to(reference, resize_to)
        reference.save(ref_save_path)
        return ref_save_path

    def optimize(self):
        """Stylize ``args.content_image`` with ``args.style_image``.

        Writes the stylized image to ``args.output_image`` and the optimized
        parameters (key ``vp``) to ``args.output_vp`` as a compressed ``.npz``.
        Intermediate files go into a timestamped directory under ``result/``.
        """
        base_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
        # exist_ok: a task that fails quickly lets the next one start within
        # the same second, which would otherwise raise FileExistsError here.
        os.makedirs(base_dir, exist_ok=True)

        content = Image.open(self.args.content_image)
        style = Image.open(self.args.style_image)

        def set_iter(iteration):
            # Iteration hook handed to single_optimize; exposes progress.
            self.cur_iteration = iteration

        effect, preset, _ = get_default_settings("minimal_pipeline")
        effect.enable_checkpoints()

        ref_save_path = self._make_reference(content, style, base_dir)

        # Mutates the shared module-level config (presumably read by
        # single_optimize); safe because the queue runs one task at a time.
        ST_CONFIG["n_iterations"] = self.args.iters
        vp, content_img_cuda = single_optimize(effect, preset, "l1", self.args.content_image, str(ref_save_path),
                                               write_video=False, base_dir=base_dir,
                                               iter_callback=set_iter)

        pixels = self._float_to_uint8(torch_to_np(content_img_cuda.detach().cpu() * 255.0))
        Image.fromarray(pixels).save(self.args.output_image)

        np.savez_compressed(self.args.output_vp, vp=vp.detach().cpu().numpy())
|
|
| |
|
|
class StyleTask:
    """One style-transfer job and its lifecycle state.

    ``status`` goes from ``"queued"`` to ``"running"`` (set in :meth:`run`)
    and then to ``"finished"`` or ``"error"`` (with ``error_msg`` filled in).
    """

    def __init__(self, task_id, style_filename, content_filename):
        self.content_filename = content_filename
        self.style_filename = style_filename

        self.status = "queued"
        self.task_id = task_id
        self.error_msg = ""
        # Derive output names from the full stem: split(".")[0] mapped
        # "a.b.jpg" and "a.c.jpg" to the same output files.
        self.output_filename = self._derived_name(content_filename, "_output.jpg")
        self.vp_output_filename = self._derived_name(content_filename, "_output.npz")

        self.neural_optimizer = NeuralOptimizer(Args(default_args))

    @staticmethod
    def _derived_name(filename, suffix):
        """Replace the extension of *filename* with *suffix* (e.g. ``cat.jpg`` -> ``cat_output.jpg``)."""
        return os.path.splitext(filename)[0] + suffix

    def start(self):
        """Point the optimizer at this task's files and run it on a daemon thread."""
        self.neural_optimizer.args.set_attributes(default_args)

        self.neural_optimizer.args.style_image = os.path.join(image_folder, self.style_filename)
        self.neural_optimizer.args.content_image = os.path.join(image_folder, self.content_filename)
        self.neural_optimizer.args.output_image = os.path.join(image_folder, self.output_filename)
        self.neural_optimizer.args.output_vp = os.path.join(image_folder, self.vp_output_filename)

        thread = threading.Thread(target=self.run, args=())
        thread.daemon = True
        thread.start()

    def run(self):
        """Thread body: run the optimizer and record the outcome in ``status``."""
        self.status = "running"
        try:
            self.neural_optimizer.optimize()
        except Exception as e:
            # Top-level boundary of the worker thread: report and record the
            # failure so the queue can move on to the next task.
            print("Error in task %d :" % (self.task_id), str(e))
            traceback.print_exc()
            # Fill in the message before publishing "error" so a status poll
            # never sees the error state with an empty message.
            self.error_msg = str(e)
            self.status = "error"
            return

        self.status = "finished"
        print("finished styling task: " + str(self.task_id))
|
|
class StylerQueue:
    """FIFO scheduler that runs one StyleTask at a time.

    ``queue_task``/``get_task`` are called from Flask request threads while
    ``status_checker`` polls on its own daemon thread, so every read and
    write of the task lists happens under ``_lock``.
    """

    def __init__(self):
        # Per-instance state (previously mutable class attributes shared by
        # every instance).
        self.queued_tasks = []
        self.finished_tasks = []
        self.running_task = None
        self._lock = threading.Lock()

        # Start polling only after the state above exists.
        thread = threading.Thread(target=self.status_checker, args=())
        thread.daemon = True
        thread.start()

    def _find_task(self, task_id):
        """Look up a task by id in all states; caller must hold ``_lock``."""
        if self.running_task is not None and self.running_task.task_id == task_id:
            return self.running_task
        return next((task for task in self.queued_tasks + self.finished_tasks if task.task_id == task_id), None)

    def _new_task_id(self):
        """Return an unused random id in ``[0, 2**53)``; caller must hold ``_lock``.

        The old time-hash ids collided for requests in the same clock tick.
        Staying below 2**53 keeps ids exact as JavaScript/JSON numbers.
        """
        while True:
            task_id = int.from_bytes(os.urandom(7), "big") >> 3  # 53 random bits
            if self._find_task(task_id) is None:
                return task_id

    def queue_task(self, *args):
        """Create a StyleTask from *args* (style, content filenames), enqueue it, return its id."""
        global total_task_count
        with self._lock:
            total_task_count += 1
            task_id = self._new_task_id()
            print("queued task num. ", total_task_count, "with ID", task_id)
            task = StyleTask(task_id, *args)
            self.queued_tasks.append(task)
        return task_id

    def get_task(self, task_id):
        """Return the task with *task_id* (queued, running or finished), or None."""
        with self._lock:
            return self._find_task(task_id)

    def _advance(self):
        """Retire a completed running task and start the next queued one; caller holds ``_lock``."""
        current = self.running_task
        if current is not None and current.status in ("finished", "error"):
            self.finished_tasks.append(current)
            self.running_task = current = None
        if current is None and self.queued_tasks:
            # pop(0) mutates in place; the old ``queued_tasks = queued_tasks[1:]``
            # rebinding could drop a task appended concurrently.
            self.running_task = self.queued_tasks.pop(0)
            self.running_task.start()

    def status_checker(self):
        """Poll forever (every 0.3 s), advancing the queue when the running task ends."""
        while True:
            time.sleep(0.3)
            with self._lock:
                self._advance()


styler_queue = StylerQueue()
|
|
|
|
@app.route('/upload', methods=['POST'])
def upload():
    """Accept a style/content image pair and queue a stylization task.

    Expects multipart files ``style-image`` and ``content-image``. Returns
    ``{"task_id": <int>}``, or a 400 JSON error when inputs are missing or
    not allowed image types.
    """
    from flask_uploads import UploadNotAllowed

    if 'style-image' not in request.files or 'content-image' not in request.files:
        # Return the JSON error directly: abort(response, 400) makes werkzeug
        # look the Response up as a status code, which fails with a 500.
        return jsonify(message="request needs style, content image"), 400

    try:
        style_filename = photos.save(request.files['style-image'])
        content_filename = photos.save(request.files['content-image'])
    except UploadNotAllowed:
        # Raised by the UploadSet (restricted to IMAGES) for other extensions.
        return jsonify(message="only image uploads are allowed"), 400

    job_id = styler_queue.queue_task(style_filename, content_filename)
    print('added new stylization task', style_filename, content_filename)
    return jsonify({"task_id": job_id})
|
|
@app.route('/get_status')
def get_status():
    """Return a task's ``status`` and ``msg``, plus ``progress`` in [0, 1] while running.

    Responds with a 400 JSON error for a missing/non-integer ``task_id`` or
    an unknown task.
    """
    # type=int yields None for a missing or non-numeric value instead of
    # raising (int(None) / int("x") previously produced a 500).
    task_id = request.args.get("task_id", type=int)
    if task_id is None:
        return jsonify(message="request needs an integer task_id"), 400

    task = styler_queue.get_task(task_id)
    if task is None:
        return jsonify(message="task with id %d not found" % task_id), 400

    status = {
        "status": task.status,
        "msg": task.error_msg,
    }

    if task.status == "running" and isinstance(task, StyleTask):
        # Use the task's own iteration budget, not the global default.
        total_iters = float(task.neural_optimizer.args.iters)
        if total_iters > 0:
            status["progress"] = float(task.neural_optimizer.cur_iteration) / total_iters

    return jsonify(status)
|
|
@app.route('/queue_length')
def get_queue_length():
    """Return ``{"length": n}``: queued tasks plus the running one, if any."""
    queued = len(styler_queue.queued_tasks)
    running = int(styler_queue.running_task is not None)
    return jsonify({"length": queued + running})
|
|
|
|
@app.route('/get_image')
def get_image():
    """Send the stylized JPEG of a finished task (400 JSON error otherwise)."""
    task_id = request.args.get("task_id", type=int)
    if task_id is None:
        return jsonify(message="request needs an integer task_id"), 400

    task = styler_queue.get_task(task_id)
    if task is None:
        return jsonify(message="task with id %d not found" % task_id), 400
    if task.status != "finished":
        return jsonify(message="task with id %d not in finished state" % task_id), 400

    # Absolute path: Flask resolves a relative send_file path against the
    # app's root_path, but the output was written relative to the CWD.
    path = os.path.abspath(os.path.join(image_folder, task.output_filename))
    # "image/jpeg" is the registered MIME type ("image/jpg" is not).
    return send_file(path, mimetype='image/jpeg')
|
|
@app.route('/get_vp')
def get_vp():
    """Send the ``.npz`` archive of optimized parameters of a finished task."""
    task_id = request.args.get("task_id", type=int)
    if task_id is None:
        return jsonify(message="request needs an integer task_id"), 400

    task = styler_queue.get_task(task_id)
    if task is None:
        return jsonify(message="task with id %d not found" % task_id), 400
    if task.status != "finished":
        return jsonify(message="task with id %d not in finished state" % task_id), 400

    # Absolute path: Flask resolves a relative send_file path against the
    # app's root_path, but the archive was written relative to the CWD.
    path = os.path.abspath(os.path.join(image_folder, task.vp_output_filename))
    # .npz files are zip archives.
    return send_file(path, mimetype='application/zip')
|
|
|
|
if __name__ == '__main__':
    # Serve on all network interfaces, port 8600, with the debugger off.
    app.run(host="0.0.0.0", port=8600, debug=False)
|
|