Onise committed on
Commit
20d7a17
·
verified ·
1 Parent(s): e46a321

Update utils.py from anycoder

Browse files
Files changed (1) hide show
  1. utils.py +111 -0
utils.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import hashlib
3
+ from typing import Optional
4
+ from pathlib import Path
5
+ import torch
6
+
7
def get_device_info() -> dict:
    """Return a snapshot of CUDA availability and memory usage.

    Returns:
        dict with keys:
            cuda_available (bool): whether CUDA is usable.
            device_count (int): number of visible CUDA devices (0 if none).
            current_device (int | None): active device index, or None.
            device_name (str | None): name of the active device, or None.
            memory_allocated (float | None): memory allocated on the active
                device, in GB, or None when CUDA is unavailable.
            memory_reserved (float | None): memory reserved on the active
                device, in GB, or None when CUDA is unavailable.
    """
    cuda = torch.cuda.is_available()
    info = {
        "cuda_available": cuda,
        "device_count": torch.cuda.device_count() if cuda else 0,
        "current_device": None,
        "device_name": None,
        "memory_allocated": None,
        "memory_reserved": None,
    }

    if cuda:
        # Fix: query the *current* device rather than hard-coding index 0,
        # so the report stays consistent on multi-GPU hosts.
        device = torch.cuda.current_device()
        info["current_device"] = device
        info["device_name"] = torch.cuda.get_device_name(device)
        info["memory_allocated"] = torch.cuda.memory_allocated(device) / 1024**3  # GB
        info["memory_reserved"] = torch.cuda.memory_reserved(device) / 1024**3  # GB

    return info
25
+
26
+
27
def format_time(seconds: float) -> str:
    """Render a duration in seconds as a short human-readable string.

    Values under a minute are shown in seconds, under an hour in minutes,
    and anything longer in hours — always with one decimal place.
    """
    # (threshold, divisor, unit suffix) checked in ascending order.
    for threshold, divisor, suffix in ((60, 1, "s"), (3600, 60, "m")):
        if seconds < threshold:
            return f"{seconds / divisor:.1f}{suffix}"
    return f"{seconds / 3600:.1f}h"
37
+
38
+
39
def estimate_generation_time(
    num_frames: int,
    height: int,
    width: int,
    num_steps: int,
    has_lora: bool = False
) -> float:
    """Roughly estimate video generation time in seconds.

    Scales a reference timing (approximate for an A100 at 480x848, 25
    frames, 20 steps) by the requested pixel volume and step count, then
    applies a speedup factor when a fast LoRA is attached.
    """
    per_frame_step = 0.5  # reference seconds per frame per step (A100)
    reference_work = 480 * 848 * 25 * 20  # pixel volume * steps of the baseline run

    workload = height * width * num_frames
    estimate = workload * num_steps * per_frame_step / reference_work

    # A fast LoRA gives roughly a 2.5x speedup.
    return estimate * 0.4 if has_lora else estimate
59
+
60
+
61
def validate_prompt(prompt: str, max_length: int = 500) -> tuple[bool, str]:
    """Check that a prompt is non-blank and within the length budget.

    Returns a ``(is_valid, message)`` pair; the message explains any
    failure, or confirms validity.
    """
    stripped = prompt.strip() if prompt else ""
    if not stripped:
        return False, "Prompt cannot be empty"

    # Length is judged on the raw prompt, surrounding whitespace included.
    if len(prompt) > max_length:
        return False, f"Prompt too long (max {max_length} characters)"

    return True, "Valid prompt"
70
+
71
+
72
def get_cache_directory() -> Path:
    """Return the cache directory for models and LoRAs, creating it if absent.

    The directory lives at ``~/.cache/wan_video_generator``.
    """
    location = Path.home().joinpath(".cache", "wan_video_generator")
    location.mkdir(parents=True, exist_ok=True)
    return location
77
+
78
+
79
def clear_cache():
    """Wipe the model cache directory and recreate it empty.

    Returns a short status message on success.
    """
    import shutil

    target = get_cache_directory()
    if target.exists():
        shutil.rmtree(target)
        target.mkdir(parents=True, exist_ok=True)
    return "Cache cleared successfully"
87
+
88
+
89
def get_memory_usage() -> dict:
    """Report GPU memory statistics for the active CUDA device.

    Returns:
        dict with "allocated_gb", "reserved_gb", "max_allocated_gb"
        (floats, in GB) and "device_name" (str), or ``{"error": ...}``
        when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    # Fix: use the current device instead of the hard-coded index 0 so the
    # figures are correct on multi-GPU hosts.
    device = torch.cuda.current_device()
    return {
        "allocated_gb": torch.cuda.memory_allocated(device) / 1024**3,
        "reserved_gb": torch.cuda.memory_reserved(device) / 1024**3,
        "max_allocated_gb": torch.cuda.max_memory_allocated(device) / 1024**3,
        "device_name": torch.cuda.get_device_name(device),
    }
100
+
101
+
102
def optimize_memory():
    """Reclaim memory: run the Python garbage collector, then release the
    CUDA caching allocator's unused blocks and synchronize when a GPU is
    present. Returns a short status message.
    """
    import gc

    gc.collect()
    if not torch.cuda.is_available():
        return "Memory optimized"

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    return "Memory optimized"