Update sd_simple_kes_v3_fix?/simple_kes_v3.py
sd_simple_kes_v3_fix?/simple_kes_v3.py
CHANGED
@@ -922,7 +922,14 @@ class SimpleKEScheduler:
         self.change = torch.abs(self.sigs[i] - self.sigs[i - 1])
         # Safely extract scalar for both tensor and float
         self.change_log.append(self.extract_scalar(self.change))
+        # --- DIVISION-BY-ZERO FIX (v2) ---
+        blended_scalar = self.extract_scalar(self.blended_sigma)
+        if blended_scalar != 0:
+            relative_sigma_progress = (blended_scalar - self.sigs[-1].item()) / blended_scalar
+        else:
+            relative_sigma_progress = 0.0
+        # --- END FIX ---
+
         recent_changes = torch.abs(torch.tensor(self.change_log[-5:]))
         max_change = torch.max(recent_changes).item()
         mean_change = torch.mean(recent_changes).item()
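Reviewer note: the guard above prevents a ZeroDivisionError when the blended sigma collapses to exactly zero near the tail of the schedule. A minimal standalone sketch of the same pattern; the `extract_scalar` stand-in below is hypothetical, written only to mimic what the scheduler's helper appears to do:

import torch

def extract_scalar(value):
    # Hypothetical stand-in for SimpleKEScheduler.extract_scalar:
    # return a plain float for tensors and Python numbers alike.
    return value.item() if torch.is_tensor(value) else float(value)

def relative_progress(blended_sigma, final_sigma):
    # Fraction of the blended sigma already traversed; returns 0.0 when
    # the blended sigma is zero, so the division can never raise.
    blended_scalar = extract_scalar(blended_sigma)
    if blended_scalar != 0:
        return (blended_scalar - extract_scalar(final_sigma)) / blended_scalar
    return 0.0

print(relative_progress(torch.tensor(1.25), torch.tensor(0.05)))  # ~0.96
print(relative_progress(0.0, 0.0))                                # 0.0 (guarded)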
@@ -963,13 +970,20 @@ class SimpleKEScheduler:
         # Start checking for early stopping after minimum steps
         if i > self.safety_minimum_stop_step and len(self.change_log) > 10:
             # Calculate variance and dynamic threshold
-            self.blended_tensor = torch.tensor(self.prepass_blended_sigmas)
+            self.blended_tensor = torch.tensor(self.prepass_blended_sigmas, device=self.device)  # <-- device added
+
+            # --- VARIANCE-CALCULATION FIX ---
+            # self.sigs is not populated yet (it is still all zeros).
+            # We must use 'blended_tensor', which holds the real values.
+            if self.device.type == 'cpu':
+                # np.var works on the plain Python list
+                self.sigma_variance = np.var(self.prepass_blended_sigmas)
             else:
+                # torch.var works on the tensor
+                self.sigma_variance = torch.var(self.blended_tensor).item()
+            # --- END FIX ---

             self.min_sigma_threshold = self.sigma_variance * self.sigma_variance_scale
             self.prepass_log(f"\n--- Early Stopping Evaluation at Step {i} ---")
             self.prepass_log(f"Current Blended Prepass Sigma: {self.prepass_blended_sigma:.6f}")
             self.prepass_log(f"Sigma Variance: {self.sigma_variance:.6f}")
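Reviewer note: the point of this fix is that the variance of a still-unfilled buffer is identically zero, so any variance-scaled threshold degenerates to zero with it. A small sketch under the same assumptions (a plain Python list of prepass sigmas; `np.var` on the CPU path, `torch.var` on a device tensor); the sample values are hypothetical:

import numpy as np
import torch

prepass_blended_sigmas = [1.20, 0.85, 0.55, 0.30, 0.12]  # hypothetical prepass values
sigs = torch.zeros(5)  # the final buffer, not yet populated at this point

print(torch.var(sigs).item())          # 0.0 -- the bug: variance over zeros
print(np.var(prepass_blended_sigmas))  # ~0.149 -- CPU path over the real values

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
blended_tensor = torch.tensor(prepass_blended_sigmas, device=device)
print(torch.var(blended_tensor).item())  # tensor path over the real values

One caveat: `np.var` defaults to the population variance (`ddof=0`) while `torch.var` defaults to the unbiased sample variance, so the two branches will not agree exactly unless one side is given a matching correction (for example `torch.var(blended_tensor, unbiased=False)`).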
@@ -1061,10 +1075,11 @@ class SimpleKEScheduler:
         # Early Stopping Evaluation
         if i > self.safety_minimum_stop_step and len(self.change_log) > 5:
             final_target_sigma = self.sigs[-1].item()  # or use min(self.sigmas) if preferred
+            blended_scalar = self.extract_scalar(self.blended_sigma)
+            if blended_scalar != 0:
+                relative_sigma_progress = (blended_scalar - self.sigs[-1].item()) / blended_scalar
             else:
-                relative_sigma_progress = 0
+                relative_sigma_progress = 0.0
             # Optional: Show variance but no need to stop on it
             self.sigma_variance = torch.var(self.sigs).item() if self.device != 'cpu' else np.var(self.blended_sigmas)
             self.log(f"Sigma Variance: {self.sigma_variance:.6f}")
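Reviewer note: this hunk applies the same zero guard in the main pass, but the diff does not show the stopping criterion that consumes `relative_sigma_progress`, `max_change`, and `mean_change`. Purely as an illustration, a predicate of the shape these statistics suggest might look like the following; `should_stop_early`, `progress_floor`, and `change_floor` are hypothetical names, not the scheduler's API:

def should_stop_early(relative_sigma_progress, max_change, mean_change,
                      progress_floor=0.95, change_floor=1e-4):
    # Hypothetical criterion: stop once nearly the whole sigma range has
    # been traversed and the last few steps barely changed anything.
    traversed = relative_sigma_progress >= progress_floor
    stalled = max_change < change_floor and mean_change < change_floor
    return traversed and stalled

# With the zero guard, a degenerate schedule reports progress 0.0 and
# therefore never trips the early stop:
print(should_stop_early(0.0, 0.0, 0.0))     # False
print(should_stop_early(0.97, 5e-5, 2e-5))  # True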
@@ -1920,36 +1935,46 @@ class SimpleKEScheduler:
             blend_weights = self.blend_weights

         )
         self.sigma_variance = torch.var(self.sigs).item()
+        # --- BEGIN DOUBLE-SHARPENING FIX ---
+        # 'full' mode: only the full-range mask is applied
+        if self.sharpen_mode == 'full':
+            self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
+            sharpen_indices = torch.where(self.sharpen_mask < 1.0)[0].tolist()
+            self.sigs = self.sigs * self.sharpen_mask
+            self.log(f"[Sharpen Mask] 'full' sharpening applied at steps: {sharpen_indices}")
+
+        # 'last_n' mode: only the last-N mask is applied
+        elif self.sharpen_mode == 'last_n':
+            # Note: this variance check is carried over from the original (odd) logic
             if self.sigma_variance < self.sharpen_variance_threshold:
-                self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
-                sharpen_indices = torch.where(self.sharpen_mask < 1.0)[0].tolist()
-                self.sigs = self.sigs * self.sharpen_mask
-                self.log(f"[Sharpen Mask] Full sharpening applied (low variance). Steps: {sharpen_indices}")
+                self.log(f"[Sharpen Mask] 'last_n' mode, but variance is low ({self.sigma_variance:.6f} < {self.sharpen_variance_threshold}). Skipping 'last_n'.")
             else:
-                # Apply sharpening only to the last N steps
                 recent_sigs = self.sigs[-self.sharpen_last_n_steps:]
                 sharpen_mask = torch.where(recent_sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
                 sharpen_indices = torch.where(sharpen_mask < 1.0)[0].tolist()
                 self.sigs[-self.sharpen_last_n_steps:] = recent_sigs * sharpen_mask
+                self.log(f"[Sharpen Mask] 'last_n' sharpening applied at steps: {sharpen_indices}")

-                old_value = self.sigs[j].item()
-                self.sigs[j] = self.sigs[j] * self.sharpness
-                self.log(f"[Sharpening] Step {j+1}: Applied sharpening. Sigma changed from {old_value:.6f} to {self.sigs[j].item():.6f}")
-            else:
-                self.log(f"[Sharpening] Step {j+1}: No sharpening applied. Sigma: {self.sigs[j].item():.6f}")
-
-        if self.sharpen_mode in ['full', 'both']:
-            # Optional: Additional full sharpening (if needed)
+        # 'both' mode: 'full' is applied, THEN 'last_n'
+        elif self.sharpen_mode == 'both':
+            # 1. Apply 'full'
             self.sharpen_mask = torch.where(self.sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
+            sharpen_indices_full = torch.where(self.sharpen_mask < 1.0)[0].tolist()
             self.sigs = self.sigs * self.sharpen_mask
-            self.log(f"[Sharpen Mask]
+            self.log(f"[Sharpen Mask] 'both' (part 1/2): 'full' sharpening applied at steps: {sharpen_indices_full}")
+
+            # 2. Apply 'last_n' (with the same variance check)
+            if self.sigma_variance < self.sharpen_variance_threshold:
+                self.log(f"[Sharpen Mask] 'both' (part 2/2): variance is low ({self.sigma_variance:.6f} < {self.sharpen_variance_threshold}). Skipping 'last_n' part.")
+            else:
+                recent_sigs = self.sigs[-self.sharpen_last_n_steps:]
+                sharpen_mask = torch.where(recent_sigs < self.sigma_min * 1.5, self.sharpness, 1.0).to(self.device)
+                sharpen_indices_last_n = torch.where(sharpen_mask < 1.0)[0].tolist()
+                self.sigs[-self.sharpen_last_n_steps:] = recent_sigs * sharpen_mask
+                self.log(f"[Sharpen Mask] 'both' (part 2/2): 'last_n' sharpening applied at steps: {sharpen_indices_last_n}")
+
+        # --- END DOUBLE-SHARPENING FIX ---

     '''
     sigma_hash = self.generate_sigma_hash(steps, sigma_min, sigma_max, rho, device, schedule_type, decay_pattern, suffix=None)
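Reviewer note: the restructured `if/elif` makes the three modes mutually exclusive, and 'both' now applies the full-range mask first and the last-N mask second instead of stacking both unconditionally. A self-contained sketch of that dispatch on a bare tensor; `apply_sharpen` and its default values are illustrative, not the scheduler's API:

import torch

def apply_sharpen(sigs, mode, sigma_min, sharpness=0.95,
                  last_n=5, variance_threshold=0.01):
    # Illustrative re-implementation of the fixed dispatch: each mode is
    # exclusive, and 'both' runs 'full' first, then the guarded 'last_n'.
    def mask(t):
        # Dampen only values already close to sigma_min, as in the diff.
        return torch.where(t < sigma_min * 1.5, sharpness, 1.0)

    if mode == 'full':
        sigs = sigs * mask(sigs)
    elif mode == 'last_n':
        if torch.var(sigs).item() >= variance_threshold:  # variance guard
            sigs[-last_n:] = sigs[-last_n:] * mask(sigs[-last_n:])
    elif mode == 'both':
        sigs = sigs * mask(sigs)                          # part 1/2: 'full'
        if torch.var(sigs).item() >= variance_threshold:  # part 2/2: 'last_n'
            sigs[-last_n:] = sigs[-last_n:] * mask(sigs[-last_n:])
    return sigs

sigs = torch.linspace(1.0, 0.02, 10)
print(apply_sharpen(sigs.clone(), 'both', sigma_min=0.03))

Factoring the mask into a single helper keeps the branches from drifting apart, which is essentially how the original code ended up sharpening the same steps twice.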