Spaces:
Running
Running
def get_model(model_name, adapted_component, adaptor_class, num_steers, rank,
              epsilon, init_var, low_resource_mode):
    """Build the model wrapper selected by the prefix of ``model_name``.

    The backend module is imported lazily, only for the branch that is
    selected, so unused backends are never loaded.

    Args:
        model_name: Model identifier. Its prefix selects the wrapper class:
            ``"EleutherAI/gpt-neo"`` / ``"gpt2"`` -> Switching_GPTNeoModel,
            ``"lora-gpt2"`` -> LORA_GPTNeoModel,
            ``"embedding_tuning"`` -> EmbeddingTuning_GPTNeoModel,
            ``"prefix-gpt2"`` -> PREFIX_GPTNeoModel,
            ``"EleutherAI/pythia"`` -> Switching_GPTNeoXModel,
            ``"EleutherAI/gpt-j"`` -> Switching_GPTJModel,
            ``"microsoft/DialoGPT"`` -> Switching_DialoGPTModel.
            The full name (prefix included) is passed to the constructor.
        adapted_component, adaptor_class, num_steers, init_var,
        low_resource_mode: Forwarded positionally, unchanged, to the
            ``Switching_*`` constructors only.
        rank, epsilon: Forwarded to the ``Switching_*`` constructors and to
            LORA_GPTNeoModel.

    Returns:
        A ``(model, model.tokenizer)`` tuple.

    Raises:
        NotImplementedError: If ``model_name`` matches none of the
            supported prefixes.
    """
    # Shared positional arguments for every Switching_* constructor.
    steer_args = (adapted_component, adaptor_class, num_steers, rank,
                  epsilon, init_var, low_resource_mode)

    # NOTE: testing the plain "gpt2" prefix first is safe: "lora-gpt2" and
    # "prefix-gpt2" begin with different characters and are not caught here.
    if model_name.startswith(("EleutherAI/gpt-neo", "gpt2")):
        from lm_steer.models.model_gpt_neo import Switching_GPTNeoModel
        model = Switching_GPTNeoModel(model_name, *steer_args)
    elif model_name.startswith("lora-gpt2"):
        from lm_steer.models.model_lora_gpt_neo import LORA_GPTNeoModel
        model = LORA_GPTNeoModel(model_name, rank, epsilon)
    elif model_name.startswith("embedding_tuning"):
        from lm_steer.models.model_embedding_tuning_gpt_neo import \
            EmbeddingTuning_GPTNeoModel
        # NOTE(review): the "embedding_tuning" prefix is passed through as-is;
        # presumably the wrapper class strips it itself -- confirm.
        model = EmbeddingTuning_GPTNeoModel(model_name)
    elif model_name.startswith("prefix-gpt2"):
        from lm_steer.models.model_prefix_gpt_neo import PREFIX_GPTNeoModel
        model = PREFIX_GPTNeoModel(model_name)
    elif model_name.startswith("EleutherAI/pythia"):
        from lm_steer.models.model_gpt_neox import Switching_GPTNeoXModel
        model = Switching_GPTNeoXModel(model_name, *steer_args)
    elif model_name.startswith("EleutherAI/gpt-j"):
        from lm_steer.models.model_gpt_j import Switching_GPTJModel
        model = Switching_GPTJModel(model_name, *steer_args)
    elif model_name.startswith("microsoft/DialoGPT"):
        from lm_steer.models.model_dialogpt import Switching_DialoGPTModel
        model = Switching_DialoGPTModel(model_name, *steer_args)
    else:
        raise NotImplementedError(
            "Unsupported model_name {!r}; expected a name starting with one "
            "of: 'EleutherAI/gpt-neo', 'gpt2', 'lora-gpt2', "
            "'embedding_tuning', 'prefix-gpt2', 'EleutherAI/pythia', "
            "'EleutherAI/gpt-j', 'microsoft/DialoGPT'".format(model_name))
    return model, model.tokenizer