danhtran2mind's picture
Upload 84 files
7946a9d verified
import os
import sys
import torch
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',
'src', 'ghibli_stable_diffusion_synthesis',
'inference')))
from full_finetuning import inference_process as full_finetuning_inference
from lora import inference_process as lora_inference
def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed,
random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id,
lora_scale, config_path, device, dtype):
batch_size = 1
if random_seed:
seed = torch.randint(0, 4294967295, (1,)).item()
try:
model_id = finetune_model_id
if not use_lora:
pil_image = full_finetuning_inference(
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
batch_size=batch_size,
seed=seed,
config_path=config_path,
model_id=model_id,
device=device,
dtype=dtype
)
else:
model_id = lora_model_id
pil_image = lora_inference(
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
batch_size=batch_size,
seed=seed,
lora_scale=lora_scale,
config_path=config_path,
model_id=model_id,
# base_model_id=base_model_id,
device=device,
dtype=dtype
)
return pil_image, f"Generated image successfully! Seed used: {seed}"
except Exception as e:
return None, f"Failed to generate image: {e}"