akin23 commited on
Commit
cfa9dce
·
verified ·
1 Parent(s): ab00a91

Update src/facerender/animate.py

Browse files
Files changed (1) hide show
  1. src/facerender/animate.py +22 -23
src/facerender/animate.py CHANGED
@@ -153,43 +153,42 @@ class AnimateFromCoeff:
153
  optimizer_mapping=None, optimizer_discriminator=None,
154
  device='cpu'):
155
 
156
- # 1) Eğer .tar ile bitiyorsa, önce geçici klasöre
157
- if checkpoint_path.endswith(''):
158
  tmpdir = tempfile.mkdtemp()
159
  with tarfile.open(checkpoint_path, 'r') as tar:
160
  tar.extractall(path=tmpdir)
161
 
162
- # 1.a) İçerikte .pth dosyası aranır
163
- found = False
164
  for root, _, files in os.walk(tmpdir):
165
  for fname in files:
166
- if fname.endswith('.pth'):
167
- checkpoint_path = os.path.join(root, fname)
168
- found = True
169
- break
170
- if found:
171
  break
172
 
173
- # 1.b) .pth bulunamazsa data/data.pkl’e bak
174
- if not found:
175
- pkl_path = os.path.join(tmpdir, 'data', 'data.pkl')
176
- if os.path.isfile(pkl_path):
177
- checkpoint_path = pkl_path
178
- else:
179
- raise FileNotFoundError(
180
- f"{checkpoint_path}.tar içinde ne .pth ne de data/data.pkl bulundu."
181
- )
182
-
183
- # 2) Eğer doğrudan bir klasör yolu verilmişse archive/data.pkl’e bak
184
  if os.path.isdir(checkpoint_path):
185
  possible = os.path.join(checkpoint_path, 'archive', 'data.pkl')
186
  if os.path.isfile(possible):
187
  checkpoint_path = possible
188
 
189
- # 3) Torch ile checkpoint’i yükle
190
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
191
 
192
- # 4) State dict’lerini ilgili objelere ata
193
  if mapping is not None and 'mapping' in checkpoint:
194
  mapping.load_state_dict(checkpoint['mapping'])
195
  if discriminator is not None and 'discriminator' in checkpoint:
@@ -199,5 +198,5 @@ class AnimateFromCoeff:
199
  if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
200
  optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
201
 
202
- # 5) Eğer epoch bilgisi varsa döndür, yoksa 0
203
  return checkpoint.get('epoch', 0)
 
153
  optimizer_mapping=None, optimizer_discriminator=None,
154
  device='cpu'):
155
 
156
+ # Eğer .tar ile bitiyorsa, önce ve içinden .pth veya .pkl ara
157
+ if checkpoint_path.endswith('.tar'):
158
  tmpdir = tempfile.mkdtemp()
159
  with tarfile.open(checkpoint_path, 'r') as tar:
160
  tar.extractall(path=tmpdir)
161
 
162
+ found_pth = None
163
+ found_pkl = None
164
  for root, _, files in os.walk(tmpdir):
165
  for fname in files:
166
+ if fname.endswith('.pth') and found_pth is None:
167
+ found_pth = os.path.join(root, fname)
168
+ if fname.endswith('.pkl') and found_pkl is None:
169
+ found_pkl = os.path.join(root, fname)
170
+ if found_pth:
171
  break
172
 
173
+ if found_pth:
174
+ checkpoint_path = found_pth
175
+ elif found_pkl:
176
+ checkpoint_path = found_pkl
177
+ else:
178
+ raise FileNotFoundError(
179
+ f"{checkpoint_path} içinden ne .pth ne de .pkl dosyası bulunabildi."
180
+ )
181
+
182
+ # Eğer bir klasör yoluna geldi ise (nadiren kullanılır), archive altındaki data.pkl’e bak
 
183
  if os.path.isdir(checkpoint_path):
184
  possible = os.path.join(checkpoint_path, 'archive', 'data.pkl')
185
  if os.path.isfile(possible):
186
  checkpoint_path = possible
187
 
188
+ # Artık checkpoint_path kesin .pth veya .pkl
189
  checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
190
 
191
+ # State dict’leri ilgili modellere yükle
192
  if mapping is not None and 'mapping' in checkpoint:
193
  mapping.load_state_dict(checkpoint['mapping'])
194
  if discriminator is not None and 'discriminator' in checkpoint:
 
198
  if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint:
199
  optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
200
 
201
+ # Epoch bilgisi varsa döndür, yoksa 0
202
  return checkpoint.get('epoch', 0)