File size: 3,283 Bytes
307e821
 
 
e781457
 
7e6bfa5
e781457
 
 
 
 
 
 
7e6bfa5
e781457
 
 
 
7e6bfa5
e781457
307e821
7e6bfa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307e821
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e781457
 
307e821
 
 
e781457
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
from huggingface_hub import hf_hub_download
from safetensors import safe_open

# Scarica il modello .safetensors dalla tua model card
model_path = hf_hub_download(
    repo_id="PietroC01/ImgEnhancerModels",
    filename="juggernaut_reborn.safetensors"
)

# Carica il modello Stable Diffusion
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # Modello base compatibile
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
    safety_checker=None
)

# Determina se è un LoRA o un modello completo
try:
    # Opzione 1: Carica come LoRA
    pipe.load_lora_weights(model_path)
except Exception as lora_error:
    try:
        # Opzione 2: Carica direttamente nell'UNet
        with safe_open(model_path, framework="pt", device="cpu") as f:
            unet_params = {k: f.get_tensor(k) for k in f.keys() if k.startswith("unet.")}
            
        # Rimuovi il prefisso "unet." dalle chiavi se presente
        clean_unet_params = {}
        for k, v in unet_params.items():
            if k.startswith("unet."):
                clean_unet_params[k[5:]] = v  # Rimuovi il prefisso "unet."
            else:
                clean_unet_params[k] = v
                
        # Carica i pesi nell'UNet
        if clean_unet_params:
            pipe.unet.load_state_dict(clean_unet_params, strict=False)
        else:
            # Opzione 3: Prova a caricare come checkpoint completo
            pipe = StableDiffusionImg2ImgPipeline.from_single_file(
                model_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                use_safetensors=True,
                load_safety_checker=False
            )
    except Exception as e:
        print(f"Errore nel caricamento del modello: {e}")
        # Fallback: usa il modello base
        print("Utilizzo del modello base come fallback")

# Sposta il modello su GPU se disponibile
pipe.to("cuda" if torch.cuda.is_available() else "cpu")

def enhance_image(image, prompt, negative_prompt, cfg_scale, denoising_strength):
    """
    Migliora un'immagine utilizzando Stable Diffusion Img2Img.
    """
    image = Image.open(image).convert("RGB")
    enhanced_image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image,
        strength=denoising_strength,
        guidance_scale=cfg_scale
    ).images[0]
    
    return enhanced_image

# Interfaccia utente con Gradio
demo = gr.Interface(
    fn=enhance_image,
    inputs=[
        gr.Image(type="filepath", label="Carica l'immagine:"),
        gr.Textbox(label="Prompt", value="highly detailed, ultra high resolution"),
        gr.Textbox(label="Negative Prompt", value="low quality, blurry, artifacts"),
        gr.Slider(1, 20, value=7, label="CFG Scale"),
        gr.Slider(0.1, 1.0, value=0.35, label="Denoising Strength")
    ],
    outputs=gr.Image(type="pil", label="Immagine migliorata"),
    title="Image Enhancer con Juggernaut Reborn",
    description="Carica un'immagine e usa il modello Juggernaut Reborn per migliorarla."
)

if __name__ == "__main__":
    demo.launch()