dikdimon commited on
Commit
8e5f1f5
·
verified ·
1 Parent(s): 3b4fdab

Update sd_simple_kes_v3_fix?/simple_kes_v3.py

Browse files
sd_simple_kes_v3_fix?/simple_kes_v3.py CHANGED
@@ -922,7 +922,14 @@ class SimpleKEScheduler:
922
  self.change = torch.abs(self.sigs[i] - self.sigs[i - 1])
923
  # Safely extract scalar for both tensor and float
924
  self.change_log.append(self.extract_scalar(self.change))
925
- relative_sigma_progress = (self.blended_sigma - self.sigs[-1].item()) / self.blended_sigma
 
 
 
 
 
 
 
926
  recent_changes = torch.abs(torch.tensor(self.change_log[-5:]))
927
  max_change = torch.max(recent_changes).item()
928
  mean_change = torch.mean(recent_changes).item()
@@ -963,13 +970,20 @@ class SimpleKEScheduler:
963
  # Start checking for early stopping after minimum steps
964
  if i > self.safety_minimum_stop_step and len(self.change_log) > 10:
965
  # Calculate variance and dynamic threshold
966
- self.blended_tensor = torch.tensor(self.prepass_blended_sigmas)
967
- if self.device == 'cpu':
968
- self.sigma_variance = np.var(self.prepass_blended_sigmas)
 
 
 
 
 
969
  else:
970
- self.sigma_variance = torch.var(self.sigs).item()
 
 
971
 
972
- self.min_sigma_threshold = self.sigma_variance * self.sigma_variance_scale # scale factor can be tuned
973
  self.prepass_log(f"\n--- Early Stopping Evaluation at Step {i} ---")
974
  self.prepass_log(f"Current Blended Prepass Sigma: {self.prepass_blended_sigma:.6f}")
975
  self.prepass_log(f"Sigma Variance: {self.sigma_variance:.6f}")
@@ -1061,10 +1075,11 @@ class SimpleKEScheduler:
1061
  # Early Stopping Evaluation
1062
  if i > self.safety_minimum_stop_step and len(self.change_log) > 5:
1063
  final_target_sigma = self.sigs[-1].item() # or use min(self.sigmas) if preferred
1064
- if self.blended_sigma != 0:
1065
- relative_sigma_progress = (self.blended_sigma - final_target_sigma) / self.blended_sigma
 
1066
  else:
1067
- relative_sigma_progress = 0 # Assume fully converged if blended_sigma is 0
1068
  # Optional: Show variance but no need to stop on it
1069
  self.sigma_variance = torch.var(self.sigs).item() if self.device != 'cpu' else np.var(self.blended_sigmas)
1070
  self.log(f"Sigma Variance: {self.sigma_variance:.6f}")
@@ -1920,36 +1935,46 @@ class SimpleKEScheduler:
1920
  blend_weights = self.blend_weights
1921
 
1922
  )
1923
- self.sigma_variance = torch.var(self.sigs).item()
1924
- if self.sharpen_mode in ['last_n', 'both']:
 
 
 
 
 
 
 
 
 
 
1925
  if self.sigma_variance < self.sharpen_variance_threshold:
1926
- # Apply full sharpening
1927
- self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
1928
- sharpen_indices = torch.where(self.sharpen_mask < 1.0)[0].tolist()
1929
- self.sigs = self.sigs * self.sharpen_mask
1930
- self.log(f"[Sharpen Mask] Full sharpening applied (low variance). Steps: {sharpen_indices}")
1931
  else:
1932
- # Apply sharpening only to the last N steps
1933
  recent_sigs = self.sigs[-self.sharpen_last_n_steps:]
1934
  sharpen_mask = torch.where(recent_sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
1935
  sharpen_indices = torch.where(sharpen_mask < 1.0)[0].tolist()
1936
  self.sigs[-self.sharpen_last_n_steps:] = recent_sigs * sharpen_mask
 
1937
 
1938
- # Now loop per step if desired (safely inside this block)
1939
- for j in range(len(self.sigs) - self.sharpen_last_n_steps, len(self.sigs)):
1940
- if self.sigs[j] < self.sigma_min * 1.5:
1941
- old_value = self.sigs[j].item()
1942
- self.sigs[j] = self.sigs[j] * self.sharpness
1943
- self.log(f"[Sharpening] Step {j+1}: Applied sharpening. Sigma changed from {old_value:.6f} to {self.sigs[j].item():.6f}")
1944
- else:
1945
- self.log(f"[Sharpening] Step {j+1}: No sharpening applied. Sigma: {self.sigs[j].item():.6f}")
1946
-
1947
- if self.sharpen_mode in ['full', 'both']:
1948
- # Optional: Additional full sharpening (if needed)
1949
  self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
1950
- sharpen_indices = torch.where(self.sharpen_mask < 1.0)[0].tolist()
1951
  self.sigs = self.sigs * self.sharpen_mask
1952
- self.log(f"[Sharpen Mask] Full sharpening applied at steps: {sharpen_indices}")
 
 
 
 
 
 
 
 
 
 
 
 
1953
 
1954
  '''
1955
  sigma_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None)
 
922
  self.change = torch.abs(self.sigs[i] - self.sigs[i - 1])
923
  # Safely extract scalar for both tensor and float
924
  self.change_log.append(self.extract_scalar(self.change))
925
+ # --- ИСПРАВЛЕНИЕ ДЕЛЕНИЯ НА НОЛЬ (v2) ---
926
+ blended_scalar = self.extract_scalar(self.blended_sigma)
927
+ if blended_scalar != 0:
928
+ relative_sigma_progress = (blended_scalar - self.sigs[-1].item()) / blended_scalar
929
+ else:
930
+ relative_sigma_progress = 0.0
931
+ # --- КОНЕЦ ИСПРАВЛЕНИЯ ---
932
+
933
  recent_changes = torch.abs(torch.tensor(self.change_log[-5:]))
934
  max_change = torch.max(recent_changes).item()
935
  mean_change = torch.mean(recent_changes).item()
 
970
  # Start checking for early stopping after minimum steps
971
  if i > self.safety_minimum_stop_step and len(self.change_log) > 10:
972
  # Calculate variance and dynamic threshold
973
+ self.blended_tensor = torch.tensor(self.prepass_blended_sigmas, device=self.device) # <-- Добавлен device
974
+
975
+ # --- ИСПРАВЛЕНИЕ РАСЧЕТА ДИСПЕРСИИ ---
976
+ # self.sigs еще не заполнен (полон нулей).
977
+ # Мы должны использовать 'blended_tensor', который содержит реальные значения.
978
+ if self.device.type == 'cpu':
979
+ # np.var работает на списке Python
980
+ self.sigma_variance = np.var(self.prepass_blended_sigmas)
981
  else:
982
+ # torch.var работает на тензоре
983
+ self.sigma_variance = torch.var(self.blended_tensor).item()
984
+ # --- КОНЕЦ ИСПРАВЛЕНИЯ ---
985
 
986
+ self.min_sigma_threshold = self.sigma_variance * self.sigma_variance_scale
987
  self.prepass_log(f"\n--- Early Stopping Evaluation at Step {i} ---")
988
  self.prepass_log(f"Current Blended Prepass Sigma: {self.prepass_blended_sigma:.6f}")
989
  self.prepass_log(f"Sigma Variance: {self.sigma_variance:.6f}")
 
1075
  # Early Stopping Evaluation
1076
  if i > self.safety_minimum_stop_step and len(self.change_log) > 5:
1077
  final_target_sigma = self.sigs[-1].item() # or use min(self.sigmas) if preferred
1078
+ blended_scalar = self.extract_scalar(self.blended_sigma)
1079
+ if blended_scalar != 0:
1080
+ relative_sigma_progress = (blended_scalar - self.sigs[-1].item()) / blended_scalar
1081
  else:
1082
+ relative_sigma_progress = 0.0
1083
  # Optional: Show variance but no need to stop on it
1084
  self.sigma_variance = torch.var(self.sigs).item() if self.device != 'cpu' else np.var(self.blended_sigmas)
1085
  self.log(f"Sigma Variance: {self.sigma_variance:.6f}")
 
1935
  blend_weights = self.blend_weights
1936
 
1937
  )
1938
+ self.sigma_variance = torch.var(self.sigs).item()
1939
+ # --- НАЧАЛО ИСПРАВЛЕНИЯ ДВОЙНОЙ РЕЗКОСТИ ---
1940
+ # Логика для 'full': применяется только 'full'
1941
+ if self.sharpen_mode == 'full':
1942
+ self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
1943
+ sharpen_indices = torch.where(self.sharpen_mask < 1.0)[0].tolist()
1944
+ self.sigs = self.sigs * self.sharpen_mask
1945
+ self.log(f"[Sharpen Mask] 'full' sharpening applied at steps: {sharpen_indices}")
1946
+
1947
+ # Логика для 'last_n': применяется только 'last_n'
1948
+ elif self.sharpen_mode == 'last_n':
1949
+ # Примечание: эта проверка дисперсии взята из оригинальной (странной) логики
1950
  if self.sigma_variance < self.sharpen_variance_threshold:
1951
+ self.log(f"[Sharpen Mask] 'last_n' mode, but variance is low ({self.sigma_variance:.6f} < {self.sharpen_variance_threshold}). Skipping 'last_n'.")
 
 
 
 
1952
  else:
 
1953
  recent_sigs = self.sigs[-self.sharpen_last_n_steps:]
1954
  sharpen_mask = torch.where(recent_sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
1955
  sharpen_indices = torch.where(sharpen_mask < 1.0)[0].tolist()
1956
  self.sigs[-self.sharpen_last_n_steps:] = recent_sigs * sharpen_mask
1957
+ self.log(f"[Sharpen Mask] 'last_n' sharpening applied at steps: {sharpen_indices}")
1958
 
1959
+ # Логика для 'both': применяется 'full', А ЗАТЕМ 'last_n'
1960
+ elif self.sharpen_mode == 'both':
1961
+ # 1. Применяем 'full'
 
 
 
 
 
 
 
 
1962
  self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
1963
+ sharpen_indices_full = torch.where(self.sharpen_mask < 1.0)[0].tolist()
1964
  self.sigs = self.sigs * self.sharpen_mask
1965
+ self.log(f"[Sharpen Mask] 'both' (part 1/2): 'full' sharpening applied at steps: {sharpen_indices_full}")
1966
+
1967
+ # 2. Применяем 'last_n' (с той же проверкой дисперсии)
1968
+ if self.sigma_variance < self.sharpen_variance_threshold:
1969
+ self.log(f"[Sharpen Mask] 'both' (part 2/2): variance is low ({self.sigma_variance:.6f} < {self.sharpen_variance_threshold}). Skipping 'last_n' part.")
1970
+ else:
1971
+ recent_sigs = self.sigs[-self.sharpen_last_n_steps:]
1972
+ sharpen_mask = torch.where(recent_sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
1973
+ sharpen_indices_last_n = torch.where(sharpen_mask < 1.0)[0].tolist()
1974
+ self.sigs[-self.sharpen_last_n_steps:] = recent_sigs * sharpen_mask
1975
+ self.log(f"[Sharpen Mask] 'both' (part 2/2): 'last_n' sharpening applied at steps: {sharpen_indices_last_n}")
1976
+
1977
+ # --- КОНЕЦ ИСПРАВЛЕНИЯ ДВОЙНОЙ РЕЗКОСТИ ---
1978
 
1979
  '''
1980
  sigma_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None)