crash10155 commited on
Commit
896d5aa
·
verified ·
1 Parent(s): 743bf71

Update SwitcherAI/utilities.py

Browse files
Files changed (1) hide show
  1. SwitcherAI/utilities.py +57 -2
SwitcherAI/utilities.py CHANGED
@@ -16,7 +16,39 @@ from tqdm import tqdm
16
  import SwitcherAI.globals
17
  from SwitcherAI import wording
18
 
19
- TEMP_DIRECTORY_PATH = r"D:\Switcher\Temp\SwitcherAI\resize-vid"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  TEMP_OUTPUT_NAME = 'temp.mp4'
21
 
22
  # monkey patch ssl
@@ -106,7 +138,9 @@ def get_temp_frame_paths(target_path : str) -> List[str]:
106
 
107
  def get_temp_directory_path(target_path : str) -> str:
108
  target_name, _ = os.path.splitext(os.path.basename(target_path))
109
- return os.path.join(TEMP_DIRECTORY_PATH, target_name)
 
 
110
 
111
 
112
  def get_temp_output_path(target_path : str) -> str:
@@ -188,3 +222,24 @@ def encode_execution_providers(execution_providers : List[str]) -> List[str]:
188
 
189
  def decode_execution_providers(execution_providers : List[str]) -> List[str]:
190
  return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers())) if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  import SwitcherAI.globals
17
  from SwitcherAI import wording
18
 
19
+ # Dynamic temp directory path - cross-platform compatible
20
+ def get_base_temp_directory():
21
+ """Get base temp directory - cross-platform compatible"""
22
+ if hasattr(SwitcherAI.globals, 'base_temp_directory') and SwitcherAI.globals.base_temp_directory:
23
+ return SwitcherAI.globals.base_temp_directory
24
+
25
+ # Default temp directory based on environment
26
+ if os.getenv('SPACE_ID'): # HuggingFace Spaces
27
+ base_temp = os.path.join(tempfile.gettempdir(), "SwitcherAI")
28
+ elif platform.system().lower() == 'windows':
29
+ # Try to use the original Windows path if it exists and is writable
30
+ original_path = r"D:\Switcher\Temp\SwitcherAI"
31
+ if os.path.exists(os.path.dirname(original_path)):
32
+ try:
33
+ os.makedirs(original_path, exist_ok=True)
34
+ # Test if writable
35
+ test_file = os.path.join(original_path, "test_write.tmp")
36
+ with open(test_file, 'w') as f:
37
+ f.write("test")
38
+ os.remove(test_file)
39
+ base_temp = original_path
40
+ except (OSError, PermissionError):
41
+ base_temp = os.path.join(tempfile.gettempdir(), "SwitcherAI")
42
+ else:
43
+ base_temp = os.path.join(tempfile.gettempdir(), "SwitcherAI")
44
+ else: # Linux/Mac
45
+ base_temp = os.path.join(tempfile.gettempdir(), "SwitcherAI")
46
+
47
+ os.makedirs(base_temp, exist_ok=True)
48
+ return base_temp
49
+
50
+ # Use dynamic temp directory
51
+ TEMP_DIRECTORY_PATH = get_base_temp_directory()
52
  TEMP_OUTPUT_NAME = 'temp.mp4'
53
 
54
  # monkey patch ssl
 
138
 
139
  def get_temp_directory_path(target_path : str) -> str:
140
  target_name, _ = os.path.splitext(os.path.basename(target_path))
141
+ # Use the dynamic temp directory instead of hardcoded path
142
+ base_temp = get_base_temp_directory()
143
+ return os.path.join(base_temp, target_name)
144
 
145
 
146
  def get_temp_output_path(target_path : str) -> str:
 
222
 
223
  def decode_execution_providers(execution_providers : List[str]) -> List[str]:
224
  return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers())) if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
225
+
226
+
227
+ def set_temp_directory(temp_path: str) -> None:
228
+ """Set custom temp directory path"""
229
+ global TEMP_DIRECTORY_PATH
230
+ SwitcherAI.globals.base_temp_directory = temp_path
231
+ TEMP_DIRECTORY_PATH = temp_path
232
+ os.makedirs(temp_path, exist_ok=True)
233
+ print(f"✅ Temp directory set to: {temp_path}")
234
+
235
+
236
+ def get_temp_directory_info() -> dict:
237
+ """Get temp directory information for debugging"""
238
+ return {
239
+ 'base_temp': get_base_temp_directory(),
240
+ 'current_temp': TEMP_DIRECTORY_PATH,
241
+ 'platform': platform.system(),
242
+ 'is_hf_spaces': bool(os.getenv('SPACE_ID')),
243
+ 'temp_exists': os.path.exists(TEMP_DIRECTORY_PATH),
244
+ 'temp_writable': os.access(TEMP_DIRECTORY_PATH, os.W_OK) if os.path.exists(TEMP_DIRECTORY_PATH) else False
245
+ }