Spaces:
Paused
Paused
Update facefusion/installer.py
Browse files- facefusion/installer.py +6 -18
facefusion/installer.py
CHANGED
|
@@ -21,9 +21,6 @@ ONNXRUNTIMES : Dict[str, Tuple[str, str]] =\
|
|
| 21 |
{
|
| 22 |
'default': ('onnxruntime', '1.16.3')
|
| 23 |
}
|
| 24 |
-
# Hardcoding GPU configurations
|
| 25 |
-
torch = 'cuda'
|
| 26 |
-
onnxruntime = 'cuda'
|
| 27 |
|
| 28 |
if platform.system().lower() == 'linux' or platform.system().lower() == 'windows':
|
| 29 |
TORCH['cuda'] = 'cu118'
|
|
@@ -56,21 +53,12 @@ def run(program : ArgumentParser) -> None:
|
|
| 56 |
|
| 57 |
if not args.skip_venv:
|
| 58 |
os.environ['PIP_REQUIRE_VIRTUALENV'] = '1'
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
else:
|
| 66 |
-
answers = inquirer.prompt(
|
| 67 |
-
[
|
| 68 |
-
inquirer.List('torch', message = wording.get('install_dependency_help').format(dependency = 'torch'), choices = list(TORCH.keys())),
|
| 69 |
-
inquirer.List('onnxruntime', message = wording.get('install_dependency_help').format(dependency = 'onnxruntime'), choices = list(ONNXRUNTIMES.keys()))
|
| 70 |
-
])
|
| 71 |
-
if answers:
|
| 72 |
-
torch = answers['torch']
|
| 73 |
-
torch_wheel = TORCH[torch]
|
| 74 |
onnxruntime = answers['onnxruntime']
|
| 75 |
onnxruntime_name, onnxruntime_version = ONNXRUNTIMES[onnxruntime]
|
| 76 |
|
|
|
|
| 21 |
{
|
| 22 |
'default': ('onnxruntime', '1.16.3')
|
| 23 |
}
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
if platform.system().lower() == 'linux' or platform.system().lower() == 'windows':
|
| 26 |
TORCH['cuda'] = 'cu118'
|
|
|
|
| 53 |
|
| 54 |
if not args.skip_venv:
|
| 55 |
os.environ['PIP_REQUIRE_VIRTUALENV'] = '1'
|
| 56 |
+
answers = {
|
| 57 |
+
'torch': 'gpu', # Assuming 'yes' for torch GPU configuration
|
| 58 |
+
'onnxruntime': 'gpu' # Assuming 'yes' for onnxruntime GPU configuration
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
torch_wheel = TORCH[answers['torch']]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
onnxruntime = answers['onnxruntime']
|
| 63 |
onnxruntime_name, onnxruntime_version = ONNXRUNTIMES[onnxruntime]
|
| 64 |
|