| import re |
| from typing import List, Tuple |
|
|
| import invokeai.backend.util.logging as logger |
| from invokeai.app.services.model_records import UnknownModelException |
| from invokeai.app.services.shared.invocation_context import InvocationContext |
| from invokeai.backend.model_manager.config import BaseModelType, ModelType |
| from invokeai.backend.textual_inversion import TextualInversionModelRaw |
|
|
|
|
def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
    """Return every ``<...>`` textual-inversion trigger found in ``prompt``, in order.

    A trigger is an angle-bracketed run of one or more ASCII letters, digits,
    dots, commas, spaces, underscores or hyphens. The angle brackets are kept
    in each returned string, and matches never overlap.
    """
    return [match.group(0) for match in re.finditer(r"<[a-zA-Z0-9., _-]+>", prompt)]
|
|
|
|
def _load_ti_model_by_key_or_name(name_or_key: str, base: BaseModelType, context: InvocationContext):
    """Load the model a prompt trigger refers to, trying it as a key first and then as a name.

    The name lookup is restricted to textual-inversion models of the given ``base``.

    Raises:
        UnknownModelException: if neither lookup finds a model.
        Any other exception raised by ``context.models.load`` / ``load_by_attrs``
        propagates unchanged to the caller.
    """
    try:
        return context.models.load(name_or_key)
    except UnknownModelException:
        # Not a model key: treat the trigger text as a model name instead. Any exception
        # raised here propagates to the caller, whose handlers cover both lookups alike.
        return context.models.load_by_attrs(name=name_or_key, base=base, type=ModelType.TextualInversion)


def generate_ti_list(
    prompt: str, base: BaseModelType, context: InvocationContext
) -> List[Tuple[str, TextualInversionModelRaw]]:
    """Resolve the ``<...>`` triggers in ``prompt`` to loaded textual-inversion models.

    The text inside each trigger is tried first as a model key, then as the name of a
    textual-inversion model for ``base``. Triggers that match no model are skipped
    silently. A warning is logged and the trigger skipped in these cases:
    - the lookup is ambiguous (``ValueError``);
    - loading fails;
    - the trigger resolves to something that is not a ``TextualInversionModelRaw`` for ``base``.

    Args:
        prompt: Prompt text to scan for triggers.
        base: Base model type the textual inversions must match.
        context: Invocation context used to load models.

    Returns:
        ``(name_or_key, model)`` pairs in trigger order, where ``name_or_key`` is the
        trigger text without its angle brackets.
    """
    ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
    for trigger in extract_ti_triggers_from_prompt(prompt):
        name_or_key = trigger[1:-1]
        try:
            loaded_model = _load_ti_model_by_key_or_name(name_or_key, base, context)
            model = loaded_model.model
        except UnknownModelException:
            # Neither a known key nor a known name - presumably ordinary prompt text.
            continue
        except ValueError:
            logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
            continue
        except AssertionError:
            logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
            continue
        except Exception:
            # Best effort: one unloadable TI must not abort processing of the whole prompt.
            logger.warning(f'Failed to load TI model for trigger: "{trigger}"')
            continue

        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if not isinstance(model, TextualInversionModelRaw) or loaded_model.config.base != base:
            logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
            continue
        ti_list.append((name_or_key, model))
    return ti_list
|
|