|
|
import torch |
|
|
from config import Config |
|
|
|
|
|
class Generator:
    """Text-to-image generator wrapping a diffusers-style pipeline.

    The pipeline itself is owned by *model_handler*; this class only
    assembles the final prompt, manages seeding, and runs inference.
    """

    def __init__(self, model_handler):
        # model_handler is expected to expose a `.pipeline` callable
        # (diffusers-style) whose result has an `.images` list.
        self.mh = model_handler

    def predict(
        self,
        user_prompt,
        negative_prompt="",
        guidance_scale=1.2,
        num_inference_steps=8,
        seed=-1,
    ):
        """Generate one image for *user_prompt* and return it.

        Parameters
        ----------
        user_prompt : str or None
            Free-text prompt. When empty, whitespace-only, or None, a
            default pixel-art landscape prompt is used instead.
        negative_prompt : str
            Forwarded to the pipeline unchanged.
        guidance_scale : float
            Classifier-free guidance strength forwarded to the pipeline.
        num_inference_steps : int
            Number of denoising steps forwarded to the pipeline.
        seed : int or None
            RNG seed. -1 or None requests a fresh random seed.

        Returns
        -------
        The first image produced by the pipeline (``.images[0]``).
        """
        # Guard against None as well as empty/whitespace strings — the
        # original `user_prompt.strip()` raised AttributeError on None.
        if not user_prompt or not user_prompt.strip():
            final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful landscape, pixel art"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"

        print(f"Prompt: {final_prompt}")

        # Draw a fresh random seed when the caller did not pin one;
        # check None first so the identity case is handled explicitly.
        if seed is None or seed == -1:
            seed = torch.Generator().seed()
        # Seed a generator on the inference device so results are reproducible.
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")

        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            generator=generator,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            # NOTE(review): clip_skip=2 presumably matches the fine-tuned
            # checkpoint's training setting — confirm against the model card.
            clip_skip=2,
        ).images[0]

        return result