File size: 3,241 Bytes
20d7a17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import hashlib
from typing import Optional
from pathlib import Path
import torch

def get_device_info() -> dict:
    """Return a summary of CUDA availability and the active device.

    Returns:
        dict with keys:
            cuda_available (bool): whether CUDA is usable.
            device_count (int): number of CUDA devices, 0 if unavailable.
            current_device (int | None): index of the active device.
            device_name (str | None): name of the active device.
            memory_allocated (float | None): GB allocated on the active device.
            memory_reserved (float | None): GB reserved on the active device.
    """
    cuda = torch.cuda.is_available()  # query once instead of three times
    info = {
        "cuda_available": cuda,
        "device_count": torch.cuda.device_count() if cuda else 0,
        "current_device": None,
        "device_name": None,
        "memory_allocated": None,
        "memory_reserved": None,
    }

    if cuda:
        # Bug fix: the original reported name/memory for device 0 even when a
        # different device was current. Use the current device consistently.
        dev = torch.cuda.current_device()
        info["current_device"] = dev
        info["device_name"] = torch.cuda.get_device_name(dev)
        info["memory_allocated"] = torch.cuda.memory_allocated(dev) / 1024**3  # GB
        info["memory_reserved"] = torch.cuda.memory_reserved(dev) / 1024**3  # GB

    return info


def format_time(seconds: float) -> str:
    """Render a duration in seconds as a short human-readable string.

    Durations under a minute are shown in seconds, under an hour in
    minutes, and anything longer in hours, each with one decimal place.
    """
    # Each tier: (upper bound, divisor, unit suffix); fall through to hours.
    for bound, divisor, unit in ((60, 1, "s"), (3600, 60, "m")):
        if seconds < bound:
            return f"{seconds / divisor:.1f}{unit}"
    return f"{seconds / 3600:.1f}h"


def estimate_generation_time(
    num_frames: int,
    height: int,
    width: int,
    num_steps: int,
    has_lora: bool = False
) -> float:
    """Estimate generation time in seconds for the given job parameters.

    The estimate scales linearly with total pixel count and step count,
    normalized against a reference 480x848, 25-frame, 20-step run on an
    A100 that is assumed to take about 0.5 units of work.
    """
    # pixels * steps product of the reference workload used for normalization
    reference_workload = 480 * 848 * 25 * 20

    workload = height * width * num_frames
    estimate = workload * num_steps * 0.5 / reference_workload

    # Fast LoRA gives roughly a 2.5x speedup, i.e. scale by 0.4.
    return estimate * 0.4 if has_lora else estimate


def validate_prompt(prompt: str, max_length: int = 500) -> tuple[bool, str]:
    """Check that a prompt is non-blank and within the length limit.

    Returns:
        (True, "Valid prompt") on success, otherwise (False, reason).
    """
    # Treat None/empty/whitespace-only all as "empty".
    stripped = prompt.strip() if prompt else ""
    if not stripped:
        return False, "Prompt cannot be empty"

    # Length is measured on the raw prompt, not the stripped form.
    if len(prompt) > max_length:
        return False, f"Prompt too long (max {max_length} characters)"

    return True, "Valid prompt"


def get_cache_directory() -> Path:
    """Return the cache directory for models and LoRAs, creating it if needed."""
    directory = Path.home().joinpath(".cache", "wan_video_generator")
    # Idempotent: safe to call whether or not the tree already exists.
    directory.mkdir(parents=True, exist_ok=True)
    return directory


def clear_cache():
    """Wipe the model cache directory and recreate it empty."""
    import shutil

    target = get_cache_directory()
    # rmtree the whole tree, then restore an empty directory so later
    # callers can write into it immediately.
    if target.exists():
        shutil.rmtree(target)
    target.mkdir(parents=True, exist_ok=True)
    return "Cache cleared successfully"


def get_memory_usage() -> dict:
    """Report current GPU memory usage in GB for device 0.

    Returns:
        {"error": ...} when CUDA is unavailable, otherwise a dict with
        allocated, reserved, and peak-allocated memory plus the device name.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    gib = 1024**3  # bytes per GiB
    return {
        "allocated_gb": torch.cuda.memory_allocated(0) / gib,
        "reserved_gb": torch.cuda.memory_reserved(0) / gib,
        "max_allocated_gb": torch.cuda.max_memory_allocated(0) / gib,
        "device_name": torch.cuda.get_device_name(0),
    }


def optimize_memory():
    """Free unused memory: run GC, then drop cached CUDA allocations."""
    import gc

    # Collect Python-side garbage first so dead tensors release their
    # CUDA storage before the cache is emptied.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    return "Memory optimized"