Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision.utils import make_grid | |
| from huggingface_hub import hf_hub_download | |
| import numpy as np | |
| import os | |
| from PIL import Image | |
| import gradio as gr | |
| # --- 0. PARAMÈTRES ET CONFIGURATION --- | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| LATENT_DIM = 128 | |
| YOUR_USERNAME = "Clemylia" | |
| # --- NOUVEAU : DÉFINITION DES MODÈLES DISPONIBLES --- | |
| MODEL_CONFIGS = { | |
| # Clé (affichée dans le menu Gradio) : Valeurs (utilisées pour le chargement HF) | |
| "Forza-ia Large (1M)": { | |
| "repo_id": f"{YOUR_USERNAME}/Forza-ia-large-1M", | |
| "file_name": "forza_ia_vae.pth" | |
| }, | |
| "Forza-ia Base": { | |
| # J'assume le nom du dépôt et du fichier pour le second modèle. | |
| # Adapte si 'forza-ia' a des noms de fichiers différents. | |
| "repo_id": f"{YOUR_USERNAME}/forza-ia", | |
| "file_name": "forza_ia_vae.pth" | |
| } | |
| } | |
| # Variable globale pour stocker les modèles déjà chargés (Mise en Cache) | |
| # Clé: 'repo_id', Valeur: instance VAE chargée | |
| MODEL_CACHE = {} | |
| print(f"Appareil utilisé : {DEVICE}") | |
| # --- 1. DÉFINITION DE L'ARCHITECTURE VAE --- | |
| # L'architecture doit être compatible avec les deux modèles si elle est la même pour les deux. | |
| # Si 'Forza-ia Base' a une architecture différente, il faudra l'adapter ! | |
| class VAE(nn.Module): | |
| def __init__(self, latent_dim=LATENT_DIM): | |
| super(VAE, self).__init__() | |
| self.latent_dim = latent_dim | |
| # ENCODEUR | |
| self.encoder = nn.Sequential( | |
| nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(), | |
| nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(), | |
| nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.ReLU(), | |
| nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), nn.ReLU(), | |
| nn.Flatten() | |
| ) | |
| self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim) | |
| self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim) | |
| # DÉCODEUR | |
| self.decoder_input = nn.Linear(latent_dim, 256 * 4 * 4) | |
| self.decoder = nn.Sequential( | |
| nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), nn.ReLU(), | |
| nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(), | |
| nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(), | |
| nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1), | |
| nn.Tanh() | |
| ) | |
| def decode(self, z): | |
| h = self.decoder_input(z) | |
| h = h.view(-1, 256, 4, 4) | |
| return self.decoder(h) | |
| # --- 2. FONCTION DE CHARGEMENT AVEC MISE EN CACHE --- | |
| def load_model(repo_id, file_name): | |
| """ | |
| Charge le modèle ou le récupère depuis le cache. | |
| """ | |
| # Vérifie si le modèle est déjà dans le cache | |
| if repo_id in MODEL_CACHE: | |
| print(f"Modèle {repo_id} récupéré du cache.") | |
| return MODEL_CACHE[repo_id] | |
| print(f"Chargement du modèle : {repo_id} / {file_name}") | |
| # TÉLÉCHARGEMENT DES POIDS | |
| try: | |
| model_path = hf_hub_download(repo_id=repo_id, filename=file_name) | |
| except Exception as e: | |
| raise gr.Error(f"Échec de chargement des poids pour {repo_id}. Erreur : {e}") | |
| # CHARGEMENT DU MODÈLE | |
| try: | |
| model = VAE().to(DEVICE) | |
| model.load_state_dict(torch.load(model_path, map_location=DEVICE)) | |
| model.eval() | |
| # Ajoute le modèle au cache avant de le retourner | |
| MODEL_CACHE[repo_id] = model | |
| print(f"Modèle {repo_id} chargé et mis en cache.") | |
| return model | |
| except Exception as e: | |
| raise gr.Error(f"Erreur lors de l'initialisation ou du chargement de l'état du modèle {repo_id}. Erreur : {e}") | |
| # --- 3. FONCTION DE GÉNÉRATION POUR GRADIO (MISE À JOUR) --- | |
| def generate_images(model_choice: str, num_images: int) -> Image.Image: | |
| """ | |
| Génère une grille d'images à partir du VAE sélectionné. | |
| Args: | |
| model_choice: La clé du modèle sélectionné par l'utilisateur. | |
| num_images: Le nombre d'images aléatoires à générer. | |
| Returns: | |
| Une image PIL contenant la grille des images générées. | |
| """ | |
| # Récupérer les paramètres HF à partir du choix utilisateur | |
| try: | |
| config = MODEL_CONFIGS[model_choice] | |
| repo_id = config["repo_id"] | |
| file_name = config["file_name"] | |
| except KeyError: | |
| raise gr.Error("Sélection de modèle invalide.") | |
| # 1. Charger/Récupérer le modèle | |
| current_model = load_model(repo_id, file_name) | |
| # 2. Logique de Génération | |
| root = int(np.sqrt(num_images)) | |
| num_samples_to_use = root * root | |
| with torch.no_grad(): | |
| # Échantillonnage | |
| sample = torch.randn(num_samples_to_use, LATENT_DIM).to(DEVICE) | |
| # Décodage | |
| generated_images = current_model.decode(sample).cpu() | |
| # Re-normalisation et Conversion | |
| generated_images = (generated_images + 1) / 2 | |
| grid = make_grid(generated_images, nrow=root) | |
| np_grid = grid.permute(1, 2, 0).numpy() | |
| pil_image = Image.fromarray((np_grid * 255).astype(np.uint8)) | |
| return pil_image | |
| # --- 4. INTERFACE GRADIO (MISE À JOUR) --- | |
| # 1. Composant de sélection du modèle (le nouveau) | |
| model_dropdown = gr.Dropdown( | |
| label="Sélectionnez le Modèle Forza-ia VAE", | |
| choices=list(MODEL_CONFIGS.keys()), # Utilise les clés du dictionnaire pour les options | |
| value=list(MODEL_CONFIGS.keys())[0], # Sélectionne le premier modèle par défaut | |
| interactive=True | |
| ) | |
| # 2. Composant d'entrée du nombre d'images (comme avant) | |
| num_images_slider = gr.Slider( | |
| minimum=1, | |
| maximum=25, | |
| value=9, | |
| step=1, | |
| label="Nombre d'images à générer", | |
| ) | |
| # 3. Composant de sortie (comme avant) | |
| output_component = gr.Image( | |
| label="Génération Aléatoire VAE", | |
| width=512, | |
| height=512 | |
| ) | |
| # Création de l'interface | |
| title = "🎨 Forza-ia VAE - Double Modèle Générateur" | |
| description = "Choisissez entre le modèle 'Forza-ia Large (1M)' et le modèle 'Forza-ia Base' pour générer des images de dessins d'enfants à partir de l'espace latent." | |
| iface = gr.Interface( | |
| fn=generate_images, | |
| # Les entrées sont maintenant le menu déroulant ET le curseur | |
| inputs=[model_dropdown, num_images_slider], | |
| outputs=output_component, | |
| title=title, | |
| description=description, | |
| allow_flagging="never" | |
| ) | |
| # Lancement de l'application | |
| if __name__ == "__main__": | |
| iface.launch() |