akin23 commited on
Commit
25fc296
·
verified ·
1 Parent(s): 1c96585

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +21 -10
src/facerender/animate.py CHANGED
@@ -153,33 +153,43 @@ class AnimateFromCoeff:
153
  optimizer_mapping=None, optimizer_discriminator=None,
154
  device='cpu'):
155
 
156
- # 1) .tar ise içeriği ve .pth bul
157
- if checkpoint_path.endswith(".tar"):
158
  tmpdir = tempfile.mkdtemp()
159
- with tarfile.open(checkpoint_path, "r") as tar:
160
  tar.extractall(path=tmpdir)
 
 
161
  found = False
162
  for root, _, files in os.walk(tmpdir):
163
  for fname in files:
164
- if fname.endswith(".pth"):
165
  checkpoint_path = os.path.join(root, fname)
166
  found = True
167
  break
168
  if found:
169
  break
170
- if not found:
171
- raise FileNotFoundError(f"{checkpoint_path} içinde .pth dosyası bulunamadı.")
172
 
173
- # 2) Klasör yüklendiyse archive/data.pkl’e bak
 
 
 
 
 
 
 
 
 
 
174
  if os.path.isdir(checkpoint_path):
175
- possible = os.path.join(checkpoint_path, "archive", "data.pkl")
176
  if os.path.isfile(possible):
177
  checkpoint_path = possible
178
 
179
- # 3) checkpoint’i yükle
180
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
181
 
182
- # 4) State dict’leri ata
183
  if mapping is not None and 'mapping' in checkpoint:
184
  mapping.load_state_dict(checkpoint['mapping'])
185
  if discriminator is not None and 'discriminator' in checkpoint:
@@ -189,4 +199,5 @@ class AnimateFromCoeff:
189
  if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
190
  optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
191
 
 
192
  return checkpoint.get('epoch', 0)
 
153
  optimizer_mapping=None, optimizer_discriminator=None,
154
  device='cpu'):
155
 
156
+ # 1) Eğer .tar ile bitiyorsa, önce geçici klasöre
157
+ if checkpoint_path.endswith('.tar'):
158
  tmpdir = tempfile.mkdtemp()
159
+ with tarfile.open(checkpoint_path, 'r') as tar:
160
  tar.extractall(path=tmpdir)
161
+
162
+ # 1.a) İçerikte .pth dosyası aranır
163
  found = False
164
  for root, _, files in os.walk(tmpdir):
165
  for fname in files:
166
+ if fname.endswith('.pth'):
167
  checkpoint_path = os.path.join(root, fname)
168
  found = True
169
  break
170
  if found:
171
  break
 
 
172
 
173
+ # 1.b) .pth bulunamazsa data/data.pkl’e bak
174
+ if not found:
175
+ pkl_path = os.path.join(tmpdir, 'data', 'data.pkl')
176
+ if os.path.isfile(pkl_path):
177
+ checkpoint_path = pkl_path
178
+ else:
179
+ raise FileNotFoundError(
180
+ f"{checkpoint_path}.tar içinde ne .pth ne de data/data.pkl bulundu."
181
+ )
182
+
183
+ # 2) Eğer doğrudan bir klasör yolu verilmişse archive/data.pkl’e bak
184
  if os.path.isdir(checkpoint_path):
185
+ possible = os.path.join(checkpoint_path, 'archive', 'data.pkl')
186
  if os.path.isfile(possible):
187
  checkpoint_path = possible
188
 
189
+ # 3) Torch ile checkpoint’i yükle
190
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
191
 
192
+ # 4) State dict’lerini ilgili objelere ata
193
  if mapping is not None and 'mapping' in checkpoint:
194
  mapping.load_state_dict(checkpoint['mapping'])
195
  if discriminator is not None and 'discriminator' in checkpoint:
 
199
  if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
200
  optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
201
 
202
+ # 5) Eğer epoch bilgisi varsa döndür, yoksa 0
203
  return checkpoint.get('epoch', 0)