fffiloni commited on
Commit
a7f1cad
·
verified ·
1 Parent(s): dc57498

Update app_wip.py

Browse files
Files changed (1) hide show
  1. app_wip.py +126 -0
app_wip.py CHANGED
@@ -19,3 +19,129 @@ snapshot_download(
19
  repo_id='JaydenLu666/Reward-Forcing-T2V-1.3B',
20
  local_dir='./checkpoints/Reward-Forcing-T2V-1.3B'
21
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  repo_id='JaydenLu666/Reward-Forcing-T2V-1.3B',
20
  local_dir='./checkpoints/Reward-Forcing-T2V-1.3B'
21
  )
22
+
23
+ import os
24
+ import uuid
25
+ import subprocess
26
+ from datetime import datetime
27
+
28
+ import gradio as gr
29
+
30
+ # === Chemins à adapter si besoin ===
31
+ CONFIG_PATH = "configs/reward_forcing.yaml"
32
+ CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
33
+
34
+ PROMPT_DIR = "prompts/gradio_inputs"
35
+ OUTPUT_ROOT = "videos/gradio_outputs"
36
+
37
+ os.makedirs(PROMPT_DIR, exist_ok=True)
38
+ os.makedirs(OUTPUT_ROOT, exist_ok=True)
39
+
40
+
41
+ def run_inference(prompt: str, duration: str, use_ema: bool):
42
+ """
43
+ 1. Écrit le prompt dans un fichier .txt
44
+ 2. Lance inference.py avec ce fichier comme --data_path
45
+ 3. Retourne le chemin de la première vidéo .mp4 générée + les logs
46
+ """
47
+ if not prompt or not prompt.strip():
48
+ raise gr.Error("Veuillez entrer un prompt texte 🙂")
49
+
50
+ # 1) On mappe la durée choisie → num_output_frames
51
+ if duration == "5s (21 frames)":
52
+ num_output_frames = 21
53
+ else: # "30s (120 frames)"
54
+ num_output_frames = 120
55
+
56
+ # 2) Fichier .txt temporaire pour le prompt
57
+ prompt_id = uuid.uuid4().hex[:8]
58
+ prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
59
+
60
+ with open(prompt_path, "w", encoding="utf-8") as f:
61
+ # TextDataset lit juste chaque ligne comme un prompt
62
+ f.write(prompt.strip() + "\n")
63
+
64
+ # 3) Dossier de sortie unique pour cette génération
65
+ ts = datetime.now().strftime("%Y%m%d_%H%M%S")
66
+ output_folder = os.path.join(OUTPUT_ROOT, f"{ts}_{prompt_id}")
67
+ os.makedirs(output_folder, exist_ok=True)
68
+
69
+ # 4) Commande inference.py
70
+ cmd = [
71
+ "python",
72
+ "inference.py",
73
+ "--num_output_frames", str(num_output_frames),
74
+ "--config_path", CONFIG_PATH,
75
+ "--checkpoint_path", CHECKPOINT_PATH,
76
+ "--output_folder", output_folder,
77
+ "--data_path", prompt_path,
78
+ "--num_samples", "1",
79
+ ]
80
+ if use_ema:
81
+ cmd.append("--use_ema")
82
+
83
+ result = subprocess.run(
84
+ cmd,
85
+ stdout=subprocess.PIPE,
86
+ stderr=subprocess.STDOUT,
87
+ text=True,
88
+ )
89
+
90
+ logs = result.stdout
91
+
92
+ # 5) On récupère la première vidéo produite
93
+ mp4s = [f for f in os.listdir(output_folder) if f.lower().endswith(".mp4")]
94
+ if not mp4s:
95
+ raise gr.Error(
96
+ "Aucune vidéo trouvée dans le dossier de sortie.\n"
97
+ "Regarde les logs ci-dessous pour voir ce qui a coincé."
98
+ )
99
+
100
+ mp4s.sort()
101
+ video_path = os.path.join(output_folder, mp4s[0])
102
+ return video_path, logs
103
+
104
+
105
+ with gr.Blocks(title="Reward Forcing T2V Demo") as demo:
106
+ gr.Markdown(
107
+ """
108
+ # 🎬 Reward Forcing – Text-to-Video
109
+
110
+ Entrez un prompt texte, on génère un fichier `.txt` en interne
111
+ puis on lance `inference.py` avec ce fichier comme `--data_path`.
112
+ """
113
+ )
114
+
115
+ with gr.Row():
116
+ prompt_in = gr.Textbox(
117
+ label="Prompt",
118
+ placeholder="Ex: A cinematic shot of a spaceship flying above a neon city at night...",
119
+ lines=4,
120
+ )
121
+
122
+ with gr.Row():
123
+ duration = gr.Radio(
124
+ ["5s (21 frames)", "30s (120 frames)"],
125
+ value="5s (21 frames)",
126
+ label="Durée",
127
+ )
128
+ use_ema = gr.Checkbox(value=True, label="Utiliser les poids EMA (--use_ema)")
129
+
130
+ generate_btn = gr.Button("🚀 Générer la vidéo", variant="primary")
131
+
132
+ with gr.Row():
133
+ video_out = gr.Video(label="Vidéo générée")
134
+ logs_out = gr.Textbox(
135
+ label="Logs de inference.py",
136
+ lines=10,
137
+ interactive=False,
138
+ )
139
+
140
+ generate_btn.click(
141
+ fn=run_inference,
142
+ inputs=[prompt_in, duration, use_ema],
143
+ outputs=[video_out, logs_out],
144
+ )
145
+
146
+ if __name__ == "__main__":
147
+ demo.launch()