| | import functools |
| | import pprint |
| | import random |
| | from pathlib import Path |
| |
|
| | import yaml |
| |
|
| | from modules import shared |
| | from modules.loaders import loaders_samplers |
| | from modules.logging_colors import logger |
| |
|
| |
|
def default_preset():
    """Return a freshly built dict with the baseline value of every preset parameter.

    A new dict is created on each call, so callers may mutate the result
    without affecting later calls. The key order defined here is the order
    ``presets_params()`` reports.
    """
    return dict(
        temperature=1,
        temperature_last=False,
        dynamic_temperature=False,
        dynatemp_low=1,
        dynatemp_high=1,
        dynatemp_exponent=1,
        smoothing_factor=0,
        top_p=1,
        min_p=0,
        top_k=0,
        repetition_penalty=1,
        presence_penalty=0,
        frequency_penalty=0,
        repetition_penalty_range=1024,
        typical_p=1,
        tfs=1,
        top_a=0,
        epsilon_cutoff=0,
        eta_cutoff=0,
        guidance_scale=1,
        penalty_alpha=0,
        mirostat_mode=0,
        mirostat_tau=5,
        mirostat_eta=0.1,
        do_sample=True,
        encoder_repetition_penalty=1,
        no_repeat_ngram_size=0,
        min_length=0,
        num_beams=1,
        length_penalty=1,
        early_stopping=False,
        # Newline-separated sampler names, stored as a single string.
        sampler_priority='temperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nmirostat',
    )
| |
|
| |
|
def presets_params():
    """Return the names of all preset parameters, in ``default_preset()`` key order."""
    # Iterating a dict yields its keys, so list() copies them directly; the
    # previous identity comprehension ``[k for k in ...]`` was redundant.
    return list(default_preset())
| |
|
| |
|
def load_preset(name):
    """Return the sampler parameters for the preset called *name*.

    Starts from ``default_preset()`` and overlays every key found in
    ``presets/{name}.yaml`` (a path relative to the working directory).
    The defaults are returned unchanged when:

    * *name* is ``'None'``, ``None`` or ``''``;
    * the preset file does not exist (an error is logged);
    * the file is empty (YAML parses it to ``None``);
    * the file's top level is not a mapping (an error is logged).

    Keys in the file that are not in ``default_preset()`` are still copied
    into the result.
    """
    generate_params = default_preset()
    if name in ['None', None, '']:
        return generate_params

    # NOTE(review): *name* is interpolated into a filesystem path unchecked;
    # if it can come from an external API, confirm callers restrict it.
    path = Path(f'presets/{name}.yaml')
    if not path.exists():
        logger.error(f"The preset \"{name}\" does not exist under \"{path}\". Using the default parameters.")
        return generate_params

    with open(path, 'r', encoding='utf-8') as infile:
        preset = yaml.safe_load(infile)

    # An empty YAML document loads as None; treat it as "no overrides"
    # instead of failing with "'NoneType' object is not iterable".
    if preset is None:
        return generate_params

    if not isinstance(preset, dict):
        logger.error(f"The preset \"{name}\" at \"{path}\" is not a mapping of parameter names to values. Using the default parameters.")
        return generate_params

    generate_params.update(preset)
    return generate_params
| |
|
| |
|
@functools.lru_cache(maxsize=None)
def load_preset_memoized(name):
    """Return ``load_preset(name)``, computing it at most once per distinct *name*.

    Every later call with the same *name* hands back the very same dict
    object, so callers must treat the result as read-only. The cache also
    holds the defaults-only result for a missing preset, so that preset's
    "does not exist" error is logged only on the first call. Edits to a
    preset file on disk are not seen until
    ``load_preset_memoized.cache_clear()`` is called.
    """
    params = load_preset(name)
    return params
| |
|
| |
|
def load_preset_for_ui(name, state):
    """Merge preset *name* into *state* and return its values.

    *state* is updated in place with every key ``load_preset`` returns.
    The return value is a tuple: *state* first, then the preset's value
    for each parameter in ``presets_params()`` order.
    """
    preset_values = load_preset(name)
    state.update(preset_values)
    ordered_values = [preset_values[param] for param in presets_params()]
    return (state, *ordered_values)
| |
|
| |
|
def random_preset(state):
    """Build a randomized preset and merge it into *state*.

    Starting from ``default_preset()``, one parameter is drawn from each
    category in the table below, and it gets a random value from that
    parameter's candidate list.

    When ``shared.args.loader`` is set, only parameters that
    ``loader_contains`` accepts are eligible. A category with no eligible
    parameter is skipped. Drawing ``dynamic_temperature`` enables it and
    sets ``dynatemp_low`` / ``dynatemp_high`` from the drawn pair.

    The non-default values of the updated *state* are pretty-printed after
    a log line. Returns a tuple: *state* followed by every generated preset
    value in ``presets_params()`` order.
    """
    candidates = {
        'remove_tail_tokens': {
            'top_p': [0.5, 0.8, 0.9, 0.95, 0.99],
            'min_p': [0.5, 0.2, 0.1, 0.05, 0.01],
            'top_k': [3, 5, 10, 20, 30, 40],
            'typical_p': [0.2, 0.575, 0.95],
            'tfs': [0.5, 0.8, 0.9, 0.95, 0.99],
            'top_a': [0.5, 0.2, 0.1, 0.05, 0.01],
            'epsilon_cutoff': [1, 3, 5, 7, 9],
            'eta_cutoff': [3, 6, 9, 12, 15, 18],
        },
        'flatten_distribution': {
            'temperature': [0.1, 0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0, 5.0],
            # [low, high] pairs with high strictly above low, in low-major order:
            # lows 0.1/0.5/0.8 pair with every high; low 1 pairs with 1.5/2/5.
            'dynamic_temperature': [
                [low, high]
                for low in (0.1, 0.5, 0.8, 1)
                for high in (1, 1.5, 2, 5)
                if high > low
            ],
            'smoothing_factor': [0.2, 0.3, 0.6, 1.2],
        },
        'repetition': {
            'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
            'presence_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
            'frequency_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
        },
        'other': {
            'temperature_last': [True, False],
        },
    }

    params = default_preset()
    filter_by_loader = shared.args.loader is not None
    for options in candidates.values():
        eligible = [param for param in options if not filter_by_loader or loader_contains(param)]
        if not eligible:
            continue

        picked = random.choice(eligible)
        drawn = random.choice(options[picked])
        if picked == 'dynamic_temperature':
            params['dynamic_temperature'] = True
            params['dynatemp_low'], params['dynatemp_high'] = drawn
        else:
            params[picked] = drawn

    state.update(params)
    logger.info("GENERATED_PRESET=")
    printer = pprint.PrettyPrinter(indent=4, width=1, sort_dicts=False)
    printer.pprint(remove_defaults(state))
    return (state, *[params[key] for key in presets_params()])
| |
|
| |
|
def loader_contains(sampler):
    """Return True if the current loader (``shared.args.loader``) lists *sampler*.

    ``dynamic_temperature`` also counts as supported when the loader lists
    ``dynatemp_low``. The loader is looked up in ``loaders_samplers``
    without a guard, so an unknown loader is not handled here.
    """
    supported = loaders_samplers[shared.args.loader]
    if sampler in supported:
        return True
    return sampler == 'dynamic_temperature' and 'dynatemp_low' in supported
| |
|
| |
|
def remove_defaults(state):
    """Return the preset parameters in *state* whose values differ from the defaults.

    Only keys from ``presets_params()`` are considered, in that order.
    Other keys in *state* are ignored, and *state* itself is not modified.
    """
    baseline = default_preset()
    return {
        param: state[param]
        for param in presets_params()
        if state[param] != baseline[param]
    }
| |
|
| |
|
def generate_preset_yaml(state):
    """Serialize the non-default preset values in *state* to a YAML string.

    Keys keep ``presets_params()`` order (``sort_keys=False``).
    """
    non_default = remove_defaults(state)
    yaml_text = yaml.dump(non_default, sort_keys=False)
    return yaml_text
| |
|