# learnable-ai/src/pipeline/image_processing_pipeline.py
# Commit 32a0eda (ishworrsubedii): Added new features and improved code formatting.
"""
Created By: ishwor subedi
Date: 2024-08-13
"""
import numpy as np
import torch
from PIL import Image
from transformers import BitsAndBytesConfig
from src.services.image_caption.caption import ImageCaption
from src.services.image_generation.image_generate import ImageGenerator
class ImageProcessingPipeline:
    """Single entry point that bundles an image captioner and an image generator.

    The constructor instantiates both services; the public methods forward
    their arguments to the matching service and hand back its result unchanged.
    """

    def __init__(self):
        """Create the captioning service (4-bit quantized) and the generator service."""
        # 4-bit weight loading with float16 as the compute dtype.
        four_bit_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
        self.image_caption = ImageCaption(
            model_id="llava-hf/llava-1.5-7b-hf",
            quantization_config=four_bit_config,
        )
        self.image_generator = ImageGenerator()

    def generate_image(self, prompt, negative_prompt, style, use_negative_prompt, num_inference_steps,
                       num_images_per_prompt, seed, width, height, guidance_scale, randomize_seed) -> Image:
        """Forward every argument, by keyword, to ``ImageGenerator.generate_image``.

        Returns:
            Whatever the underlying generator returns, untouched.
            NOTE(review): the ``Image`` annotation names the PIL module, not a
            class; the generator's real return type is not visible here — confirm.
        """
        generation_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "style": style,
            "use_negative_prompt": use_negative_prompt,
            "num_inference_steps": num_inference_steps,
            "num_images_per_prompt": num_images_per_prompt,
            "seed": seed,
            "width": width,
            "height": height,
            "guidance_scale": guidance_scale,
            "randomize_seed": randomize_seed,
        }
        return self.image_generator.generate_image(**generation_kwargs)

    def generate_caption(self, image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length,
                         top_p):
        """Forward the image, prompt and sampling settings to ``ImageCaption.generate``.

        The call is positional: a fresh empty list goes first, then ``prompt``
        and ``image`` (note: swapped relative to this method's own parameter
        order), then the remaining settings in signature order.

        Returns:
            Whatever the underlying captioner returns, untouched.
        """
        sampling_args = (temperature, length_penalty, repetition_penalty, max_length, min_length, top_p)
        return self.image_caption.generate([], prompt, image, *sampling_args)