|
|
import torch |
|
|
from config import Config |
|
|
|
|
|
class Generator:
    """Turn a user prompt into a single image via the model handler's pipeline."""

    def __init__(self, model_handler):
        """Store the model handler.

        Args:
            model_handler: Object exposing a callable ``pipeline`` attribute,
                which ``predict`` invokes.
        """
        self.mh = model_handler

    def predict(
        self,
        user_prompt,
        negative_prompt="",
        aspect_ratio="1:1",
        guidance_scale=1.6,
        num_inference_steps=8,
        seed=-1,
    ):
        """Run the pipeline once and return the first generated image.

        Args:
            user_prompt: Text appended after ``Config.STYLE_TRIGGER``. ``None``,
                empty, or whitespace-only values select a built-in default
                prompt instead.
            negative_prompt: Forwarded unchanged to the pipeline.
            aspect_ratio: Key into ``Config.ASPECT_RATIOS``; unknown keys fall
                back to ``Config.DEFAULT_ASPECT_RATIO``.
            guidance_scale: Forwarded unchanged to the pipeline.
            num_inference_steps: Forwarded unchanged to the pipeline.
            seed: Seed for the torch generator. ``-1`` or ``None`` requests a
                freshly drawn random seed.

        Returns:
            The first element of the ``images`` attribute of the pipeline's
            result.
        """
        # A missing prompt (None) is handled like an empty one rather than
        # failing on ``None.strip()``.
        if user_prompt is None or not user_prompt.strip():
            final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful landscape, pixel art"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
        print(f"Prompt: {final_prompt}")

        # Resolve the output size. When the requested key is unknown, say so
        # and log the key that is actually used, so the logged dimensions are
        # never attributed to a ratio that was not applied.
        if aspect_ratio not in Config.ASPECT_RATIOS:
            print(
                f"Unknown aspect ratio {aspect_ratio!r}; "
                f"falling back to {Config.DEFAULT_ASPECT_RATIO}"
            )
            aspect_ratio = Config.DEFAULT_ASPECT_RATIO
        width, height = Config.ASPECT_RATIOS[aspect_ratio]
        print(f"Aspect Ratio: {aspect_ratio} ({width}x{height})")

        # -1 / None mean "pick a seed for me"; the chosen value is printed so
        # the run can be repeated with the same seed.
        if seed is None or seed == -1:
            seed = torch.Generator().seed()
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")

        print("Running pipeline...")
        output = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            generator=generator,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            clip_skip=1,
        )
        # Only the first image of the result is returned.
        return output.images[0]
|
|
|