"""
Created By: ishwor subedi
Date: 2024-08-13
"""
import numpy as np
import torch
from PIL import Image
from transformers import BitsAndBytesConfig

from src.services.image_caption.caption import ImageCaption
from src.services.image_generation.image_generate import ImageGenerator


class ImageProcessingPipeline:
    """Bundle an image-caption service and an image-generation service.

    The constructor builds one ``ImageCaption`` and one ``ImageGenerator``;
    the two public methods forward their arguments to those services and
    return whatever the service returns.
    """

    def __init__(self) -> None:
        """Create the captioning and generation services.

        The caption model is requested with a 4-bit ``BitsAndBytesConfig``
        whose compute dtype is ``torch.float16``.
        """
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.image_caption = ImageCaption(
            model_id="llava-hf/llava-1.5-7b-hf",
            quantization_config=quantization_config,
        )
        self.image_generator = ImageGenerator()

    def generate_image(self, prompt, negative_prompt, style, use_negative_prompt,
                       num_inference_steps, num_images_per_prompt, seed, width,
                       height, guidance_scale, randomize_seed) -> Image.Image:
        """Forward every argument, by keyword, to the image generator.

        All parameters are passed through unchanged to
        ``self.image_generator.generate_image`` under the same names.

        Returns:
            The value returned by ``ImageGenerator.generate_image``.
            NOTE(review): annotated as a single PIL image, following the
            original author's intent; if the generator returns several images
            when ``num_images_per_prompt > 1``, widen the annotation.
        """
        # ``Image`` is the PIL module; the class usable as a type is
        # ``Image.Image``, hence the return annotation above.
        return self.image_generator.generate_image(
            prompt=prompt,
            negative_prompt=negative_prompt,
            style=style,
            use_negative_prompt=use_negative_prompt,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images_per_prompt,
            seed=seed,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            randomize_seed=randomize_seed,
        )

    def generate_caption(self, image, prompt, temperature, length_penalty,
                         repetition_penalty, max_length, min_length, top_p):
        """Forward the arguments, positionally, to the caption service.

        Note the order handed to ``self.image_caption.generate``: an empty
        list first, then ``prompt``, then ``image``, followed by the sampling
        parameters in the order they appear in this signature.

        Returns:
            The value returned by ``ImageCaption.generate``.
        """
        # The leading empty list is presumably a (fresh) conversation history;
        # TODO confirm against ImageCaption.generate.
        return self.image_caption.generate(
            [],
            prompt,
            image,
            temperature,
            length_penalty,
            repetition_penalty,
            max_length,
            min_length,
            top_p,
        )