shekkari21 commited on
Commit
efb85e3
·
1 Parent(s): 3c45764

added files for inference

Browse files
Files changed (4) hide show
  1. .gitignore +37 -0
  2. app.py +45 -5
  3. requirements.txt +1 -0
  4. src/autoencoder.py +26 -1
.gitignore ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Large model files (stored on Hugging Face)
2
+ checkpoints/
3
+ pretrained_weights/
4
+
5
+ # Python
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+ *.so
10
+ .Python
11
+ *.egg-info/
12
+ dist/
13
+ build/
14
+
15
+ # Environment
16
+ .env
17
+ .venv
18
+ env/
19
+ venv/
20
+
21
+ # IDE
22
+ .vscode/
23
+ .idea/
24
+ *.swp
25
+ *.swo
26
+ *~
27
+
28
+ # OS
29
+ .DS_Store
30
+ Thumbs.db
31
+
32
+ # Jupyter
33
+ .ipynb_checkpoints/
34
+
35
+ # Logs
36
+ *.log
37
+
app.py CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
8
  import torchvision.transforms.functional as TF
9
  from pathlib import Path
10
  import sys
 
11
 
12
  # Add src to path
13
  sys.path.insert(0, str(Path(__file__).parent / "src"))
@@ -17,6 +18,9 @@ from autoencoder import get_vqgan
17
  from noiseControl import resshift_schedule
18
  from config import device, T, k, normalize_input, latent_flag, gt_size
19
 
 
 
 
20
  # Global variables for loaded models
21
  model = None
22
  autoencoder = None
@@ -30,9 +34,12 @@ def load_models():
30
  print("Loading models...")
31
 
32
  # Load model checkpoint
33
- checkpoint_path = "checkpoints/ckpts/model_3200.pth" # Update with your checkpoint path
34
- if not Path(checkpoint_path).exists():
35
- # Try to find any checkpoint
 
 
 
36
  ckpt_dir = Path("checkpoints/ckpts")
37
  if ckpt_dir.exists():
38
  checkpoints = list(ckpt_dir.glob("model_*.pth"))
@@ -40,9 +47,42 @@ def load_models():
40
  checkpoint_path = str(checkpoints[-1]) # Use latest
41
  print(f"Using checkpoint: {checkpoint_path}")
42
  else:
43
- raise FileNotFoundError("No model checkpoint found!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  else:
45
- raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  model = FullUNET()
48
  model = model.to(device)
 
8
  import torchvision.transforms.functional as TF
9
  from pathlib import Path
10
  import sys
11
+ from huggingface_hub import hf_hub_download
12
 
13
  # Add src to path
14
  sys.path.insert(0, str(Path(__file__).parent / "src"))
 
18
  from noiseControl import resshift_schedule
19
  from config import device, T, k, normalize_input, latent_flag, gt_size
20
 
21
+ # Hugging Face repo ID for weights
22
+ HF_WEIGHTS_REPO_ID = "shekkari21/DiffusionSR-weights"
23
+
24
  # Global variables for loaded models
25
  model = None
26
  autoencoder = None
 
34
  print("Loading models...")
35
 
36
  # Load model checkpoint
37
+ checkpoint_path = "checkpoints/ckpts/model_3200.pth"
38
+ checkpoint_file = Path(checkpoint_path)
39
+
40
+ # Download from Hugging Face if not found locally
41
+ if not checkpoint_file.exists():
42
+ # Try to find any checkpoint locally first
43
  ckpt_dir = Path("checkpoints/ckpts")
44
  if ckpt_dir.exists():
45
  checkpoints = list(ckpt_dir.glob("model_*.pth"))
 
47
  checkpoint_path = str(checkpoints[-1]) # Use latest
48
  print(f"Using checkpoint: {checkpoint_path}")
49
  else:
50
+ # Download from Hugging Face
51
+ print(f"Checkpoint not found locally. Downloading from Hugging Face...")
52
+ try:
53
+ # Files are in root of weights repo, download to local directory structure
54
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
55
+ downloaded_path = hf_hub_download(
56
+ repo_id=HF_WEIGHTS_REPO_ID,
57
+ filename="model_3200.pth",
58
+ local_dir=str(ckpt_dir),
59
+ local_dir_use_symlinks=False
60
+ )
61
+ checkpoint_path = str(ckpt_dir / "model_3200.pth")
62
+ print(f"✓ Downloaded checkpoint: {checkpoint_path}")
63
+ except Exception as e:
64
+ raise FileNotFoundError(
65
+ f"Could not download checkpoint from Hugging Face: {e}\n"
66
+ f"Please ensure the file exists in the repository."
67
+ )
68
  else:
69
+ # Create directory and download
70
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
71
+ print(f"Checkpoint not found locally. Downloading from Hugging Face...")
72
+ try:
73
+ downloaded_path = hf_hub_download(
74
+ repo_id=HF_WEIGHTS_REPO_ID,
75
+ filename="model_3200.pth",
76
+ local_dir=str(ckpt_dir),
77
+ local_dir_use_symlinks=False
78
+ )
79
+ checkpoint_path = str(ckpt_dir / "model_3200.pth")
80
+ print(f"✓ Downloaded checkpoint: {checkpoint_path}")
81
+ except Exception as e:
82
+ raise FileNotFoundError(
83
+ f"Could not download checkpoint from Hugging Face: {e}\n"
84
+ f"Please ensure the file exists in the repository."
85
+ )
86
 
87
  model = FullUNET()
88
  model = model.to(device)
requirements.txt CHANGED
@@ -10,4 +10,5 @@ lpips>=0.1.4
10
  loralib>=0.1.2
11
  python-dotenv>=1.0.0
12
  numpy>=1.24.0
 
13
 
 
10
  loralib>=0.1.2
11
  python-dotenv>=1.0.0
12
  numpy>=1.24.0
13
+ huggingface_hub>=0.20.0
14
 
src/autoencoder.py CHANGED
@@ -6,6 +6,7 @@ import torch.nn as nn
6
  from pathlib import Path
7
  import sys
8
  import os
 
9
 
10
  # Handle import of ldm from latent-diffusion repository
11
  # Check if ldm directory exists locally (from latent-diffusion repo)
@@ -49,6 +50,9 @@ from config import (
49
  device
50
  )
51
 
 
 
 
52
 
53
  def load_vqgan(ckpt_path=None, device=device):
54
  """
@@ -68,8 +72,29 @@ def load_vqgan(ckpt_path=None, device=device):
68
  if not Path(ckpt_path).is_absolute():
69
  ckpt_path = _project_root / ckpt_path
70
 
 
71
  if not Path(ckpt_path).exists():
72
- raise FileNotFoundError(f"VQGAN checkpoint not found at: {ckpt_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  print(f"Loading VQGAN from: {ckpt_path}")
75
 
 
6
  from pathlib import Path
7
  import sys
8
  import os
9
+ from huggingface_hub import hf_hub_download
10
 
11
  # Handle import of ldm from latent-diffusion repository
12
  # Check if ldm directory exists locally (from latent-diffusion repo)
 
50
  device
51
  )
52
 
53
+ # Hugging Face repo ID for weights
54
+ HF_WEIGHTS_REPO_ID = "shekkari21/DiffusionSR-weights"
55
+
56
 
57
  def load_vqgan(ckpt_path=None, device=device):
58
  """
 
72
  if not Path(ckpt_path).is_absolute():
73
  ckpt_path = _project_root / ckpt_path
74
 
75
+ # Download from Hugging Face if not found locally
76
  if not Path(ckpt_path).exists():
77
+ print(f"VQGAN checkpoint not found locally. Downloading from Hugging Face...")
78
+ try:
79
+ # Files are in root of weights repo, download to local directory structure
80
+ local_weights_dir = _project_root / "pretrained_weights"
81
+ local_weights_dir.mkdir(parents=True, exist_ok=True)
82
+
83
+ # Download from root of weights repo
84
+ downloaded_path = hf_hub_download(
85
+ repo_id=HF_WEIGHTS_REPO_ID,
86
+ filename="autoencoder_vq_f4.pth",
87
+ local_dir=str(local_weights_dir),
88
+ local_dir_use_symlinks=False
89
+ )
90
+ ckpt_path = local_weights_dir / "autoencoder_vq_f4.pth"
91
+ print(f"✓ Downloaded VQGAN checkpoint: {ckpt_path}")
92
+ except Exception as e:
93
+ raise FileNotFoundError(
94
+ f"VQGAN checkpoint not found at: {ckpt_path}\n"
95
+ f"Could not download from Hugging Face: {e}\n"
96
+ f"Please ensure the file exists in the repository."
97
+ )
98
 
99
  print(f"Loading VQGAN from: {ckpt_path}")
100