| import os |
| from uuid import uuid4 |
|
|
| import pytesseract |
| import torch |
| from PIL import Image |
| from controlnet_aux import HEDdetector |
| from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler |
| from flask import Flask, request, send_file |
| from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering |
| from transformers import pipeline |
|
|
# Flask application exposing the image-processing endpoints defined below.
app = Flask('chatgpt-plugin-extras')
|
|
|
|
class VitGPT2:
    """Image captioning backed by the nlpconnect/vit-gpt2-image-captioning model."""

    def __init__(self, device):
        print(f"Initializing VitGPT2 ImageCaptioning to {device}")
        # Fix: the requested device was previously ignored; forward it so the
        # pipeline actually runs on the device announced in the log line.
        self.pipeline = pipeline(
            "image-to-text",
            model="nlpconnect/vit-gpt2-image-captioning",
            device=device,
        )

    def inference(self, image_path):
        """Return the generated caption text for the image at *image_path*."""
        captions = self.pipeline(image_path)[0]['generated_text']
        print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
        return captions
|
|
|
|
class ImageCaptioning:
    """BLIP-large image captioning (Salesforce/blip-image-captioning-large)."""

    def __init__(self, device):
        print(f"Initializing ImageCaptioning to {device}")
        self.device = device
        # Half precision only on CUDA; CPU stays in float32.
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        model_name = "Salesforce/blip-image-captioning-large"
        self.processor = BlipProcessor.from_pretrained(model_name)
        blip = BlipForConditionalGeneration.from_pretrained(model_name, torch_dtype=self.torch_dtype)
        self.model = blip.to(self.device)

    def inference(self, image_path):
        """Generate and return a caption for the image stored at *image_path*."""
        image = Image.open(image_path)
        encoded = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
        generated = self.model.generate(**encoded)
        captions = self.processor.decode(generated[0], skip_special_tokens=True)
        print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
        return captions
|
|
|
|
class VQA:
    """Visual question answering backed by Salesforce/blip-vqa-base."""

    def __init__(self, device):
        print(f"Initializing Visual QA to {device}")
        self.device = device
        # Half precision only on CUDA; CPU stays in float32.
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        model_name = "Salesforce/blip-vqa-base"
        self.processor = BlipProcessor.from_pretrained(model_name)
        blip = BlipForQuestionAnswering.from_pretrained(model_name, torch_dtype=self.torch_dtype)
        self.model = blip.to(self.device)

    def inference(self, image_path, question):
        """Return the model's answer to *question* about the image at *image_path*."""
        image = Image.open(image_path)
        encoded = self.processor(image, question, return_tensors="pt").to(self.device, self.torch_dtype)
        generated = self.model.generate(**encoded)
        answers = self.processor.decode(generated[0], skip_special_tokens=True)
        print(f"\nProcessed Visual QA, Input Image: {image_path}, Output Text: {answers}")
        return answers
|
|
|
|
class Image2Hed:
    """HED (holistically-nested edge detection) map extraction via controlnet_aux."""

    def __init__(self, device):
        # NOTE(review): *device* is accepted for interface parity with the other
        # model wrappers but is not used by the HED detector here.
        print("Initializing Image2Hed")
        self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')

    def inference(self, inputs, output_filename):
        """Run HED on the image at *inputs*, save under data/, return the result URL path."""
        source = Image.open(inputs)
        edge_map = self.detector(source)
        output_path = os.path.join('data', output_filename)
        edge_map.save(output_path)
        print(f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed: {output_path}")
        return '/result/' + output_filename
|
|
|
|
class Image2Scribble:
    """Scribble-style sketch extraction using the HED detector's scribble mode."""

    def __init__(self, device):
        # NOTE(review): *device* is accepted for interface parity with the other
        # model wrappers but is not used by the HED detector here.
        print("Initializing Image2Scribble")
        self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')

    def inference(self, inputs, output_filename):
        """Run scribble extraction on the image at *inputs*, save under data/,
        and return the result URL path."""
        output_path = os.path.join('data', output_filename)
        image = Image.open(inputs)
        scribble = self.detector(image, scribble=True)
        scribble.save(output_path)
        # Fix: the log line previously said "Processed Image2Hed" — a copy-paste
        # leftover from Image2Hed that made the two operations indistinguishable.
        print(f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble: {output_path}")
        return '/result/' + output_filename
|
|
class InstructPix2Pix:
    """Instruction-driven image editing via timbrooks/instruct-pix2pix."""

    def __init__(self, device):
        print(f"Initializing InstructPix2Pix to {device}")
        self.device = device
        # Half precision only on CUDA; CPU stays in float32.
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            "timbrooks/instruct-pix2pix",
            safety_checker=None,
            torch_dtype=self.torch_dtype).to(device)
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)

    def inference(self, image_path, text, output_filename):
        """Change style of the image at *image_path* per the instruction *text*,
        save the result under data/, and return the result URL path."""
        print("===>Starting InstructPix2Pix Inference")
        original_image = Image.open(image_path)
        image = self.pipe(text, image=original_image,
                          num_inference_steps=40, image_guidance_scale=1.2).images[0]
        output_path = os.path.join('data', output_filename)
        image.save(output_path)
        print(f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text: {text}, "
              f"Output Image: {output_path}")
        # Fix: previously returned '/result/' + output_path, i.e. '/result/data/<file>',
        # which the '/result/<filename>' route cannot serve (get_result joins 'data'
        # itself, and the default URL converter rejects slashes). Return the bare
        # filename URL like every other wrapper in this file.
        return '/result/' + output_filename
|
|
@app.route('/result/<filename>')
def get_result(filename):
    """Serve a previously generated image from the data/ directory as PNG."""
    # Security: *filename* comes from the URL. basename() strips any directory
    # component (e.g. '..') so the lookup cannot escape the data/ directory.
    file_path = os.path.join('data', os.path.basename(filename))
    return send_file(file_path, mimetype='image/png')
|
|
|
|
# Instantiate all model wrappers once at import time on CPU; pretrained weights
# are downloaded/loaded here, so startup can be slow.
# NOTE(review): VitGPT2 is defined above but never instantiated or routed —
# confirm whether that is intentional.
ic = ImageCaptioning("cpu")
vqa = VQA("cpu")
i2h = Image2Hed("cpu")
i2s = Image2Scribble("cpu")
| |
| |
|
|
@app.route('/image2hed', methods=['POST'])
def imag2hed():
    """Accept an uploaded image and respond with the URL of its HED edge map."""
    upload = request.files['file']
    saved_path = os.path.join('data', 'upload', f"{uuid4()}.png")
    upload.save(saved_path)
    output_filename = f"{uuid4()}.png"
    return i2h.inference(saved_path, output_filename)
|
|
|
|
@app.route('/image2Scribble', methods=['POST'])
def image2Scribble():
    """Accept an uploaded image and respond with the URL of its scribble sketch."""
    upload = request.files['file']
    saved_path = os.path.join('data', 'upload', f"{uuid4()}.png")
    upload.save(saved_path)
    output_filename = f"{uuid4()}.png"
    return i2s.inference(saved_path, output_filename)
|
|
|
|
@app.route('/image-captioning', methods=['POST'])
def image_caption():
    """Accept an uploaded image and respond with its generated caption."""
    upload = request.files['file']
    saved_path = os.path.join('data', 'upload', f"{uuid4()}.png")
    upload.save(saved_path)
    return ic.inference(saved_path)
|
|
|
|
@app.route('/visual-qa', methods=['POST'])
def visual_qa():
    """Answer the question in the 'q' query parameter about the uploaded image.

    Returns HTTP 400 when 'q' is missing; previously a missing parameter passed
    None into the model and crashed inside inference.
    """
    question = request.args.get('q')
    if question is None:
        return "missing required query parameter 'q'", 400
    file = request.files['file']
    filename = str(uuid4()) + '.png'
    filepath = os.path.join('data', 'upload', filename)
    file.save(filepath)
    result = vqa.inference(filepath, question=question)
    return result
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
@app.route('/ocr', methods=['POST'])
def ocr():
    """Run Tesseract OCR on the uploaded image and return the recognized text.

    The optional 'lang' query parameter selects the Tesseract language pack;
    it defaults to simplified Chinese plus English.
    """
    upload = request.files['file']
    lang = request.args.get('lang')
    if lang is None:
        lang = 'chi_sim+eng'
    saved_path = os.path.join('data', 'upload', f"{uuid4()}.png")
    upload.save(saved_path)
    return pytesseract.image_to_string(Image.open(saved_path), lang=lang)
|
|
if __name__ == '__main__':
    # NOTE(review): binds the Flask development server on all interfaces
    # (0.0.0.0, default port 5000) — fine for local/container use; put behind
    # a production WSGI server for anything exposed.
    app.run(host='0.0.0.0')
|
|