File size: 5,551 Bytes
cb2a484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
generate.py — Script principal de génération vidéo (Texte → Vidéo)
Modèle utilisé : Wan 2.1 T2V 1.3B (Hugging Face Diffusers)

Usage :
    python generate.py --prompt "votre prompt" --output ../outputs/ma_video.mp4
"""

import argparse
import os
import sys


def _select_torch_dtype(torch_module):
    if not torch_module.cuda.is_available():
        return torch_module.float32

    if torch_module.cuda.is_bf16_supported():
        return torch_module.bfloat16

    return torch_module.float16

def generate_video(
    prompt: str,
    output_path: str,
    negative_prompt: str = "déformé, moche, flou, mauvaise qualité, artefacts, texte, filigrane",
    num_frames: int = 24,
    num_inference_steps: int = 25,
    height: int = 480,
    width: int = 832,
) -> str:
    """
    Génère une vidéo à partir d'un prompt texte et la sauvegarde dans le chemin spécifié.

    Args:
        prompt (str)             : Description textuelle de la vidéo à générer.
        output_path (str)        : Chemin de sauvegarde de la vidéo (ex: outputs/video.mp4).
        negative_prompt (str)    : Ce que l'on ne veut PAS voir dans la vidéo.
        num_frames (int)         : Nombre d'images à générer (24 ≈ 1 seconde à 24 fps).
        num_inference_steps (int): Nombre d'étapes de diffusion (plus = meilleure qualité, plus lent).
        height (int)             : Hauteur de la vidéo en pixels.
        width (int)              : Largeur de la vidéo en pixels.

    Returns:
        str: Le chemin de la vidéo générée, ou une chaîne vide en cas d'erreur.
    """
    try:
        import torch

        # Importation différée pour éviter les erreurs si les dépendances ne sont pas installées
        from diffusers import AutoencoderKLWan, WanPipeline
        from diffusers.utils import export_to_video

        print(f"[INFO] Chargement du modèle Wan 2.1 T2V 1.3B...")
        print(f"[INFO] (Le premier chargement peut prendre du temps — téléchargement des poids)")

        model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
        dtype = _select_torch_dtype(torch)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[INFO] Appareil détecté : {device} | dtype : {dtype}")

        # Chargement du VAE (encodeur/décodeur vidéo)
        vae = AutoencoderKLWan.from_pretrained(
            model_id, subfolder="vae", torch_dtype=dtype
        )

        # Chargement du pipeline principal
        pipe = WanPipeline.from_pretrained(
            model_id, vae=vae, torch_dtype=dtype
        )

        if torch.cuda.is_available():
            # Optimisation mémoire : décharge les parties inactives du modèle vers la RAM
            pipe.enable_model_cpu_offload()
        else:
            pipe = pipe.to(device)

        print(f"[INFO] Modèle chargé. Génération de la vidéo...")
        print(f"[INFO] Prompt : {prompt}")

        # Génération des frames vidéo
        output = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
        )

        # Création du dossier de sortie si nécessaire
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        # Export en fichier vidéo MP4
        video_path = export_to_video(output.frames[0], output_path, fps=16)
        print(f"[SUCCESS] Vidéo sauvegardée : {video_path}")
        return video_path

    except ImportError as e:
        print(f"[ERREUR] Dépendance manquante : {e}")
        print("[AIDE] Installez les dépendances avec : pip install -r requirements.txt")
        return ""
    except OSError as e:
        print(f"[ERREUR] Problème d'environnement Python ou de bibliothèque native : {e}")
        print("[AIDE] Utilisez de préférence Python 3.10, 3.11 ou 3.12 avec une version compatible de PyTorch.")
        return ""
    except Exception as e:
        print(f"[ERREUR] Erreur lors de la génération : {e}")
        return ""


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Générateur de vidéo IA (Texte → Vidéo) avec Wan 2.1"
    )
    parser.add_argument(
        "--prompt", type=str, required=True,
        help="Description textuelle de la vidéo à générer"
    )
    parser.add_argument(
        "--output", type=str, default="../outputs/generated_video.mp4",
        help="Chemin de sortie pour la vidéo (défaut: ../outputs/generated_video.mp4)"
    )
    parser.add_argument(
        "--negative_prompt", type=str,
        default="déformé, moche, flou, mauvaise qualité, artefacts, texte, filigrane",
        help="Ce que l'on ne veut PAS voir dans la vidéo"
    )
    parser.add_argument(
        "--num_frames", type=int, default=24,
        help="Nombre d'images à générer (défaut: 24 ≈ 1.5 secondes)"
    )
    parser.add_argument(
        "--steps", type=int, default=25,
        help="Nombre d'étapes de diffusion (défaut: 25)"
    )

    args = parser.parse_args()

    result = generate_video(
        prompt=args.prompt,
        output_path=args.output,
        negative_prompt=args.negative_prompt,
        num_frames=args.num_frames,
        num_inference_steps=args.steps,
    )

    if result:
        print(f"\n✅ Génération réussie ! Vidéo disponible ici : {result}")
    else:
        print("\n❌ La génération a échoué. Consultez les messages d'erreur ci-dessus.")
        sys.exit(1)