File size: 12,325 Bytes
aab6df4
 
 
 
f0f68cf
 
 
5a4f54c
f0f68cf
5a4f54c
 
 
f0f68cf
5a4f54c
 
 
 
 
aab6df4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
import gradio as gr
from typing import Optional, List, Dict, Any
import numpy as np
import time
import sys
from pathlib import Path

# CRITICAL FIX: Ensure local modules are prioritized over system packages.
# This prevents ImportError when a 'utils' (or 'config') package exists in
# site-packages and would otherwise shadow the local utils.py / config.py.
current_dir = Path(__file__).resolve().parent
_local_dir = str(current_dir)
# Move (not merely add) the local directory to the front: if it is already on
# sys.path but after site-packages, a plain membership check would leave the
# shadowing order unchanged.
if _local_dir in sys.path:
    sys.path.remove(_local_dir)
sys.path.insert(0, _local_dir)

# Drop any cached copies of the local module names so the imports below are
# resolved again against the reordered sys.path.
for _module_name in ("utils", "config"):
    sys.modules.pop(_module_name, None)

# Now import from the local utils.py / config.py files
from utils import WANVideoGenerator, LoRAManager, NSFWChecker
from config import MODEL_CONFIGS, AVAILABLE_LORAS, NSFW_CONFIG

# Initialize core components (module-level instances shared by every event
# handler in this file)
generator = WANVideoGenerator()
lora_manager = LoRAManager()
nsfw_checker = NSFWChecker()

def generate_video(
    image: np.ndarray,
    prompt: str,
    selected_model: str,
    enabled_loras: List[str],
    enable_nsfw: bool,
    video_length: int,
    resolution: str,
    progress=gr.Progress()
) -> tuple[str, str, Dict[str, Any]]:
    """
    Main video generation function with WAN-scale processing.

    Args:
        image: Input image as numpy array; ``None`` when nothing was uploaded.
        prompt: Optional text prompt for video generation.
        selected_model: Selected WAN model variant (passed to
            ``generator.load_model``).
        enabled_loras: Names of the LoRA adapters to activate; ``None`` is
            treated as an empty selection.
        enable_nsfw: If True, skip the input-image safety check and allow
            LoRAs flagged ``"nsfw"`` in AVAILABLE_LORAS.
        video_length: Target video length in frames; must be >= 1.
        resolution: Output resolution preset (e.g. "768x768").
        progress: Gradio progress tracker.

    Returns:
        Tuple of (video_path, status_message, generation_metadata)

    Raises:
        gr.Error: On invalid input, a safety-filter hit, a disallowed LoRA,
            or any failure during generation. Validation errors are
            re-raised unchanged; unexpected exceptions are wrapped as
            "Generation failed: ..." with the original exception chained.
    """
    try:
        # Step 1: Validate inputs
        progress(0.1, desc="πŸ” Validating inputs...")
        if image is None:
            raise gr.Error("No image provided. Please upload an image to generate video.")

        # Slider values may arrive as floats; range() below needs an int, and
        # zero frames would hand compile_video an empty list.
        video_length = int(video_length)
        if video_length < 1:
            raise gr.Error("Video length must be at least 1 frame.")

        # A cleared multiselect can come through as None.
        enabled_loras = list(enabled_loras or [])

        # Step 2: NSFW check if enabled
        if enable_nsfw and NSFW_CONFIG["require_confirmation"]:
            progress(0.15, desc="⚠️ NSFW mode active - bypassing standard filters")
        elif not enable_nsfw:
            progress(0.15, desc="πŸ›‘οΈ Running safety checks...")
            if nsfw_checker.check_image(image):
                raise gr.Error("Input image flagged by safety filter. Enable NSFW mode to bypass.")
            # The UI hides NSFW-flagged LoRAs in safe mode, but this endpoint
            # is exposed publicly, so enforce the same rule server-side.
            blocked_loras = [
                name for name in enabled_loras
                if AVAILABLE_LORAS.get(name, {}).get("nsfw", False)
            ]
            if blocked_loras:
                raise gr.Error(
                    f"LoRA adapters {', '.join(blocked_loras)} require NSFW mode to be enabled."
                )

        # Step 3: Load selected model and LoRAs
        progress(0.2, desc=f"πŸ“¦ Loading {selected_model} model...")
        generator.load_model(selected_model)

        progress(0.3, desc=f"πŸ”Œ Activating {len(enabled_loras)} LoRA adapters...")
        active_loras = lora_manager.load_loras(enabled_loras)

        # Step 4: Generate video frames (progress spans 0.4 -> 0.9)
        progress(0.4, desc="🎬 Generating video frames...")
        frames = []
        for i in range(video_length):
            progress(0.4 + (i / video_length) * 0.5,
                     desc=f"Rendering frame {i+1}/{video_length}...")
            frame = generator.generate_frame(
                image=image,
                prompt=prompt,
                frame_index=i,
                total_frames=video_length,
                active_loras=active_loras
            )
            frames.append(frame)

        # Step 5: Compile video
        progress(0.95, desc="πŸŽ₯ Compiling final video...")
        output_path = generator.compile_video(
            frames=frames,
            resolution=resolution,
            fps=30
        )

        # Step 6: Prepare metadata
        metadata = {
            "model": selected_model,
            "loras": enabled_loras,
            "nsfw_mode": enable_nsfw,
            "resolution": resolution,
            "frames": video_length,
            "prompt": prompt or "No prompt provided",
            "status": "βœ… Generation complete"
        }

        progress(1.0, desc="βœ… Done!")
        return output_path, "Video generated successfully!", metadata

    except gr.Error:
        # Already a user-facing message; re-raise it as-is instead of
        # double-wrapping it in "Generation failed: ...".
        raise
    except Exception as e:
        raise gr.Error(f"Generation failed: {e}") from e

def update_lora_visibility(enable_nsfw: bool) -> gr.Dropdown:
    """
    Rebuild the LoRA dropdown when the NSFW toggle changes.

    Args:
        enable_nsfw: If True, offer every adapter in AVAILABLE_LORAS;
            otherwise hide adapters whose config has a truthy ``"nsfw"`` flag.

    Returns:
        A ``gr.Dropdown`` carrying the filtered choices, a mode-specific
        label, and a cleared selection (``value=[]``).
    """
    if enable_nsfw:
        choices = list(AVAILABLE_LORAS.keys())
        label = "🎨 Active LoRA Adapters (NSFW options unlocked)"
    else:
        # Same insertion order as AVAILABLE_LORAS, minus NSFW-flagged entries.
        choices = [
            name for name, lora_cfg in AVAILABLE_LORAS.items()
            if not lora_cfg.get("nsfw", False)
        ]
        label = "🎨 Active LoRA Adapters (Safe mode)"

    return gr.Dropdown(
        choices=choices,
        value=[],
        multiselect=True,
        label=label
    )

def create_interface():
    """
    Build the main Gradio Blocks UI and wire up its event handlers.

    Layout: a left sidebar with model / NSFW / LoRA / video / advanced
    settings plus a status label, and a main column with the input image,
    optional prompt, generate/clear buttons, status text, output video and a
    collapsible metadata viewer.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    # Prefer the historical default model, but fall back to the first
    # configured one so the dropdown never starts on a value that is not
    # among its choices.
    preferred_model = "wan-2.1-14b"
    default_model = (
        preferred_model
        if preferred_model in MODEL_CONFIGS
        else next(iter(MODEL_CONFIGS), None)
    )

    with gr.Blocks() as demo:
        gr.HTML("""
        <div style='text-align: center; padding: 20px;'>
            <h1>πŸ₯Š WAN-Scale Image-to-Video Architecture πŸ₯Š</h1>
            <p>Built with anycoder - <a href='https://huggingface.co/spaces/akhaliq/anycoder' target='_blank'>View on Hugging Face</a></p>
            <p style='font-size: 1.2em; color: #666;'>Turn static images into dynamic videos with WAN foundation models</p>
        </div>
        """)
        
        # Global state
        # NOTE(review): not read or written by any handler below.
        generation_state = gr.State({"session_id": None})
        
        with gr.Row():
            # Sidebar for controls
            with gr.Sidebar(position="left", width=320):
                gr.Markdown("### βš™οΈ Generation Settings")
                
                model_selector = gr.Dropdown(
                    choices=list(MODEL_CONFIGS.keys()),
                    value=default_model,
                    label="🤖 WAN Model",
                    info="Select foundation model variant"
                )
                
                nsfw_toggle = gr.Checkbox(
                    value=False,
                    label="🔞 Enable NSFW Content",
                    info="Bypass safety filters (requires confirmation)"
                )
                
                # Starts in safe mode: NSFW-flagged adapters are hidden until
                # nsfw_toggle changes (see update_lora_visibility).
                lora_selector = gr.Dropdown(
                    choices=[k for k, v in AVAILABLE_LORAS.items() if not v.get("nsfw", False)],
                    value=[],
                    multiselect=True,
                    label="🎨 Active LoRA Adapters",
                    info="Select style and domain adapters"
                )
                
                with gr.Accordion("πŸ“ Video Settings", open=False):
                    video_length = gr.Slider(
                        minimum=16,
                        maximum=128,
                        value=32,
                        step=8,
                        label="Video Length (frames)"
                    )
                    
                    resolution = gr.Radio(
                        choices=["512x512", "768x768", "1024x576", "1920x1080"],
                        value="768x768",
                        label="Resolution"
                    )
                
                # NOTE(review): inference_steps and cfg_scale are not passed to
                # generate_btn.click below, so they currently have no effect on
                # generation.
                with gr.Accordion("πŸš€ Advanced Options", open=False):
                    inference_steps = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=50,
                        label="Inference Steps"
                    )
                    
                    cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=20.0,
                        value=7.5,
                        step=0.5,
                        label="CFG Scale"
                    )
                
                # Status indicators
                model_status = gr.Label(
                    value={"Status": "Ready", "VRAM": "24GB Available"},
                    label="System Status"
                )
            
            # Main content area
            with gr.Column():
                gr.Markdown("### πŸ“€ Input Image")
                input_image = gr.Image(
                    label="Upload Starting Frame",
                    type="numpy",
                    height=400,
                    sources=["upload", "webcam", "clipboard"]
                )
                
                gr.Markdown("### πŸ“ Optional Text Prompt")
                prompt_box = gr.Textbox(
                    placeholder="Describe the motion, style, or scene...",
                    label="Prompt (optional)",
                    lines=2,
                    max_lines=4
                )
                
                with gr.Row():
                    generate_btn = gr.Button(
                        "🎬 Generate Video",
                        variant="primary",
                        scale=2
                    )
                    clear_btn = gr.ClearButton(
                        components=[input_image, prompt_box],
                        value="πŸ—‘οΈ Clear"
                    )
                
                # Progress is reported through the gr.Progress() default
                # argument of generate_video; a standalone gr.Progress() here
                # would not be rendered or updated.
                status_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    show_copy_button=True
                )
                
                gr.Markdown("### πŸ“Ό Output Video")
                output_video = gr.Video(
                    label="Generated Video",
                    height=400,
                    autoplay=True,
                    show_download_button=True
                )
                
                # Generation metadata
                with gr.Accordion("πŸ“Š Generation Details", open=False):
                    metadata_json = gr.JSON(
                        label="Metadata",
                        open=False
                    )
        
        # Event handlers
        nsfw_toggle.change(
            fn=update_lora_visibility,
            inputs=nsfw_toggle,
            outputs=lora_selector,
            api_visibility="private"
        )
        
        # Input order must match generate_video's positional parameters.
        generate_btn.click(
            fn=generate_video,
            inputs=[
                input_image,
                prompt_box,
                model_selector,
                lora_selector,
                nsfw_toggle,
                video_length,
                resolution
            ],
            outputs=[
                output_video,
                status_text,
                metadata_json
            ],
            api_visibility="public",
            concurrency_limit=2  # Limit concurrent generations
        )
        
        # Update model status on selection (display only; VRAM text is static)
        model_selector.change(
            fn=lambda x: {"Status": f"Loaded {x}", "VRAM": "24GB Used"},
            inputs=model_selector,
            outputs=model_status,
            api_visibility="private"
        )
        
        # Demo load event
        demo.load(
            fn=lambda: "System initialized and ready",
            outputs=status_text,
            api_visibility="private"
        )
    
    return demo

# Create and launch the application
if __name__ == "__main__":
    # Build the Blocks app, then serve it on every interface at port 7860.
    # Theme, CSS and footer links are supplied at launch time.
    demo = create_interface()

    app_theme = gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="indigo",
        neutral_hue="slate",
        font=gr.themes.GoogleFont("Inter"),
        text_size="lg",
        spacing_size="lg",
        radius_size="md"
    ).set(
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_hover="*primary_700",
        block_title_text_weight="600",
        block_background_fill="*neutral_50"
    )

    app_footer_links = [
        {"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"},
        {"label": "Model Docs", "url": "https://huggingface.co/docs"},
        {"label": "API Reference", "url": "/docs"}
    ]

    app_css = """
        .gradio-container { max-width: 1400px; margin: auto; }
        .contain { display: flex; flex-direction: column; height: 100vh; }
        #component-0 { height: 100%; }
        .gr-button { font-weight: 600; }
        .gr-markdown { text-align: center; }
        """

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False,
        show_error=True,
        max_threads=4,
        theme=app_theme,
        footer_links=app_footer_links,
        css=app_css
    )