Forza-demo / app.py
Clemylia's picture
Create app.py
407179b verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
from huggingface_hub import hf_hub_download
import numpy as np
import os
from PIL import Image
import gradio as gr
# --- 0. PARAMÈTRES ET CONFIGURATION ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LATENT_DIM = 128
YOUR_USERNAME = "Clemylia"
# --- NOUVEAU : DÉFINITION DES MODÈLES DISPONIBLES ---
MODEL_CONFIGS = {
# Clé (affichée dans le menu Gradio) : Valeurs (utilisées pour le chargement HF)
"Forza-ia Large (1M)": {
"repo_id": f"{YOUR_USERNAME}/Forza-ia-large-1M",
"file_name": "forza_ia_vae.pth"
},
"Forza-ia Base": {
# J'assume le nom du dépôt et du fichier pour le second modèle.
# Adapte si 'forza-ia' a des noms de fichiers différents.
"repo_id": f"{YOUR_USERNAME}/forza-ia",
"file_name": "forza_ia_vae.pth"
}
}
# Variable globale pour stocker les modèles déjà chargés (Mise en Cache)
# Clé: 'repo_id', Valeur: instance VAE chargée
MODEL_CACHE = {}
print(f"Appareil utilisé : {DEVICE}")
# --- 1. DÉFINITION DE L'ARCHITECTURE VAE ---
# L'architecture doit être compatible avec les deux modèles si elle est la même pour les deux.
# Si 'Forza-ia Base' a une architecture différente, il faudra l'adapter !
class VAE(nn.Module):
def __init__(self, latent_dim=LATENT_DIM):
super(VAE, self).__init__()
self.latent_dim = latent_dim
# ENCODEUR
self.encoder = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.ReLU(),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), nn.ReLU(),
nn.Flatten()
)
self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)
# DÉCODEUR
self.decoder_input = nn.Linear(latent_dim, 256 * 4 * 4)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), nn.ReLU(),
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),
nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)
def decode(self, z):
h = self.decoder_input(z)
h = h.view(-1, 256, 4, 4)
return self.decoder(h)
# --- 2. FONCTION DE CHARGEMENT AVEC MISE EN CACHE ---
def load_model(repo_id, file_name):
"""
Charge le modèle ou le récupère depuis le cache.
"""
# Vérifie si le modèle est déjà dans le cache
if repo_id in MODEL_CACHE:
print(f"Modèle {repo_id} récupéré du cache.")
return MODEL_CACHE[repo_id]
print(f"Chargement du modèle : {repo_id} / {file_name}")
# TÉLÉCHARGEMENT DES POIDS
try:
model_path = hf_hub_download(repo_id=repo_id, filename=file_name)
except Exception as e:
raise gr.Error(f"Échec de chargement des poids pour {repo_id}. Erreur : {e}")
# CHARGEMENT DU MODÈLE
try:
model = VAE().to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()
# Ajoute le modèle au cache avant de le retourner
MODEL_CACHE[repo_id] = model
print(f"Modèle {repo_id} chargé et mis en cache.")
return model
except Exception as e:
raise gr.Error(f"Erreur lors de l'initialisation ou du chargement de l'état du modèle {repo_id}. Erreur : {e}")
# --- 3. FONCTION DE GÉNÉRATION POUR GRADIO (MISE À JOUR) ---
def generate_images(model_choice: str, num_images: int) -> Image.Image:
"""
Génère une grille d'images à partir du VAE sélectionné.
Args:
model_choice: La clé du modèle sélectionné par l'utilisateur.
num_images: Le nombre d'images aléatoires à générer.
Returns:
Une image PIL contenant la grille des images générées.
"""
# Récupérer les paramètres HF à partir du choix utilisateur
try:
config = MODEL_CONFIGS[model_choice]
repo_id = config["repo_id"]
file_name = config["file_name"]
except KeyError:
raise gr.Error("Sélection de modèle invalide.")
# 1. Charger/Récupérer le modèle
current_model = load_model(repo_id, file_name)
# 2. Logique de Génération
root = int(np.sqrt(num_images))
num_samples_to_use = root * root
with torch.no_grad():
# Échantillonnage
sample = torch.randn(num_samples_to_use, LATENT_DIM).to(DEVICE)
# Décodage
generated_images = current_model.decode(sample).cpu()
# Re-normalisation et Conversion
generated_images = (generated_images + 1) / 2
grid = make_grid(generated_images, nrow=root)
np_grid = grid.permute(1, 2, 0).numpy()
pil_image = Image.fromarray((np_grid * 255).astype(np.uint8))
return pil_image
# --- 4. INTERFACE GRADIO (MISE À JOUR) ---
# 1. Composant de sélection du modèle (le nouveau)
model_dropdown = gr.Dropdown(
label="Sélectionnez le Modèle Forza-ia VAE",
choices=list(MODEL_CONFIGS.keys()), # Utilise les clés du dictionnaire pour les options
value=list(MODEL_CONFIGS.keys())[0], # Sélectionne le premier modèle par défaut
interactive=True
)
# 2. Composant d'entrée du nombre d'images (comme avant)
num_images_slider = gr.Slider(
minimum=1,
maximum=25,
value=9,
step=1,
label="Nombre d'images à générer",
)
# 3. Composant de sortie (comme avant)
output_component = gr.Image(
label="Génération Aléatoire VAE",
width=512,
height=512
)
# Création de l'interface
title = "🎨 Forza-ia VAE - Double Modèle Générateur"
description = "Choisissez entre le modèle 'Forza-ia Large (1M)' et le modèle 'Forza-ia Base' pour générer des images de dessins d'enfants à partir de l'espace latent."
iface = gr.Interface(
fn=generate_images,
# Les entrées sont maintenant le menu déroulant ET le curseur
inputs=[model_dropdown, num_images_slider],
outputs=output_component,
title=title,
description=description,
allow_flagging="never"
)
# Lancement de l'application
if __name__ == "__main__":
iface.launch()