"""
Created By: ishwor subedi
Date: 2024-08-13
"""
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from transformers import BitsAndBytesConfig | |
| from src.services.image_caption.caption import ImageCaption | |
| from src.services.image_generation.image_generate import ImageGenerator | |
class ImageProcessingPipeline:
    """Thin facade over an image-captioning model and an image generator.

    The pipeline owns one ``ImageCaption`` instance and one
    ``ImageGenerator`` instance and forwards calls to them unchanged; it
    performs no pre- or post-processing of its own.
    """

    # Checkpoint handed to ``ImageCaption`` when the caller does not choose one.
    DEFAULT_CAPTION_MODEL_ID = "llava-hf/llava-1.5-7b-hf"

    def __init__(self, caption_model_id=DEFAULT_CAPTION_MODEL_ID):
        """Build the captioning and generation back ends.

        Args:
            caption_model_id: Model identifier passed to ``ImageCaption`` as
                ``model_id``. Defaults to ``DEFAULT_CAPTION_MODEL_ID``, which
                is the value that was previously hard-coded here.
        """
        # Request 4-bit weights with float16 compute for the caption model.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.image_caption = ImageCaption(
            model_id=caption_model_id,
            quantization_config=quantization_config,
        )
        self.image_generator = ImageGenerator()

    def generate_image(self, prompt, negative_prompt, style, use_negative_prompt, num_inference_steps,
                       num_images_per_prompt, seed, width, height, guidance_scale,
                       randomize_seed) -> "Image.Image":
        """Forward a generation request to the underlying ``ImageGenerator``.

        Every argument is passed through by keyword, unmodified, to
        ``ImageGenerator.generate_image``; see that method for the meaning
        and valid range of each one.

        Returns:
            Whatever ``ImageGenerator.generate_image`` returns. It is
            annotated as a single PIL image.
            NOTE(review): with ``num_images_per_prompt`` > 1 the generator may
            return several images - confirm against ``ImageGenerator``.
        """
        return self.image_generator.generate_image(
            prompt=prompt,
            negative_prompt=negative_prompt,
            style=style,
            use_negative_prompt=use_negative_prompt,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images_per_prompt,
            seed=seed,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            randomize_seed=randomize_seed,
        )

    def generate_caption(self, image, prompt, temperature, length_penalty, repetition_penalty, max_length,
                         min_length, top_p):
        """Forward a captioning request to the underlying ``ImageCaption``.

        The arguments are passed positionally to ``ImageCaption.generate``.
        Note that ``prompt`` is sent *before* ``image``, which is the reverse
        of this method's own parameter order.

        Returns:
            Whatever ``ImageCaption.generate`` returns.
        """
        # A fresh empty list is always supplied as the first positional
        # argument. NOTE(review): presumably a conversation history or
        # message list - confirm against ImageCaption.generate.
        return self.image_caption.generate(
            [],
            prompt,
            image,
            temperature,
            length_penalty,
            repetition_penalty,
            max_length,
            min_length,
            top_p,
        )