openfree commited on
Commit
3e75e0e
Β·
verified Β·
1 Parent(s): 0d45b9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -3
app.py CHANGED
@@ -1,12 +1,41 @@
1
  import os
 
 
2
 
3
- # Set this environment variable to disable torch.compiler features
4
  os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
5
  os.environ["TRANSFORMERS_COMPILER_DISABLED"] = "1"
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import yaml
8
  import torch
9
- import sys
10
  sys.path.append(os.path.abspath('./'))
11
  from inference.utils import *
12
  from train import WurstCoreB
@@ -18,7 +47,6 @@ import argparse
18
  import gradio as gr
19
  import spaces
20
  from huggingface_hub import hf_hub_url
21
- import subprocess
22
  from huggingface_hub import hf_hub_download
23
  from transformers import pipeline
24
 
 
1
  import os
2
+ import subprocess
3
+ import sys
4
 
5
+ # ν•„μš”ν•œ ν™˜κ²½ λ³€μˆ˜ μ„€μ •
6
  os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
7
  os.environ["TRANSFORMERS_COMPILER_DISABLED"] = "1"
8
 
9
+ # ν•„μš”ν•œ λͺ¨λ“ˆ μ„€μΉ˜ ν•¨μˆ˜
10
+ def install_required_packages():
11
+ required_packages = [
12
+ "warmup_scheduler",
13
+ "cosine_annealing_warmup_restarts"
14
+ ]
15
+
16
+ for package in required_packages:
17
+ try:
18
+ __import__(package)
19
+ print(f"{package} is already installed")
20
+ except ImportError:
21
+ print(f"Installing {package}...")
22
+ try:
23
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
24
+ print(f"{package} installed successfully")
25
+ except subprocess.CalledProcessError:
26
+ # 일뢀 νŒ¨ν‚€μ§€λŠ” PyPI에 없을 수 μžˆμœΌλ―€λ‘œ GitHubμ—μ„œ 직접 μ„€μΉ˜
27
+ if package == "warmup_scheduler":
28
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git"])
29
+ print(f"{package} installed from GitHub successfully")
30
+ else:
31
+ print(f"Failed to install {package}")
32
+
33
+ # ν•„μš”ν•œ λͺ¨λ“ˆ μ„€μΉ˜
34
+ install_required_packages()
35
+
36
+ # κ·Έ ν›„ λ‚˜λ¨Έμ§€ imports
37
  import yaml
38
  import torch
 
39
  sys.path.append(os.path.abspath('./'))
40
  from inference.utils import *
41
  from train import WurstCoreB
 
47
  import gradio as gr
48
  import spaces
49
  from huggingface_hub import hf_hub_url
 
50
  from huggingface_hub import hf_hub_download
51
  from transformers import pipeline
52