Spaces:
Sleeping
Sleeping
| from .base_handler import ModelHandler | |
| from .nlp_models.sequence_classification_handler import SequenceClassificationHandler | |
| from .nlp_models.question_answering_handler import QuestionAnsweringHandler | |
| from .nlp_models.token_classification_handler import TokenClassificationHandler | |
| from .nlp_models.causal_lm_handler import CausalLMHandler | |
| from .nlp_models.embedding_model_handler import EmbeddingModelHandler | |
| from .audio_models.whisper_handler import WhisperHandler | |
| from .nlp_models.masked_lm_handler import MaskedLMHandler | |
| from .nlp_models.seq2seq_lm_handler import Seq2SeqLMHandler | |
| from .nlp_models.multiple_choice_handler import MultipleChoiceHandler | |
| from .img_models.image_classification_handler import ImageClassificationHandler | |
| from transformers import ( | |
| AutoModel, | |
| AutoModelForTokenClassification, | |
| AutoModelForSequenceClassification, | |
| AutoModelForQuestionAnswering, | |
| AutoModelForCausalLM, | |
| AutoModelForMaskedLM, | |
| AutoModelForSeq2SeqLM, | |
| AutoModelForMultipleChoice, | |
| ) | |
# Registry of supported tasks. Each key is a task name accepted by
# get_model_handler(); each value holds:
#   "model_class"   - transformers Auto* class passed to the handler, or None for
#                     entries marked "Not implemented".
#   "handler_class" - handler class that get_model_handler() instantiates.
#   "example_text"  - sample input text for the task. It is not read by
#                     get_model_handler(); presumably it is used elsewhere as a
#                     default/demo input (confirm against callers).
TASK_CONFIGS = {
    "embedding": {
        "model_class": AutoModel,
        "handler_class": EmbeddingModelHandler,
        "example_text": "Hey, I am feeling way too good to be true.",
    },
    "ner": {
        "model_class": AutoModelForTokenClassification,
        "handler_class": TokenClassificationHandler,
        "example_text": "John works at Google in New York as a software engineer.",
    },
    "text_classification": {
        "model_class": AutoModelForSequenceClassification,
        "handler_class": SequenceClassificationHandler,
        "example_text": "This movie was great and I loved it.",
    },
    "question_answering": {
        "model_class": AutoModelForQuestionAnswering,
        "handler_class": QuestionAnsweringHandler,
        # NOTE(review): context and question appear to be separated by the
        # literal "QUES:" marker, presumably split by QuestionAnsweringHandler;
        # confirm before changing this format.
        "example_text": "The pyramids were built in ancient Egypt. QUES: Where were the pyramids built?",
    },
    "causal_lm": {
        "model_class": AutoModelForCausalLM,
        "handler_class": CausalLMHandler,
        "example_text": "Once upon a time, there was ",
    },
    "mask_lm": {
        "model_class": AutoModelForMaskedLM,
        "handler_class": MaskedLMHandler,
        # NOTE(review): "[MASK]" is the BERT-style mask token; models using a
        # different token (e.g. "<mask>") may need MaskedLMHandler to
        # substitute it; confirm.
        "example_text": "The quick brown [MASK] jumps over the lazy dog.",
    },
    "seq2seq_lm": {
        "model_class": AutoModelForSeq2SeqLM,
        "handler_class": Seq2SeqLMHandler,
        "example_text": "Translate English to French: The house is wonderful.",
    },
    "multiple_choice": {
        "model_class": AutoModelForMultipleChoice,
        "handler_class": MultipleChoiceHandler,
        "example_text": "What is the capital of France? (A) Paris (B) London (C) Berlin (D) Rome",
    },
    "whisper_finetuning": {
        "model_class": None,  # Not implemented
        "handler_class": WhisperHandler,
        "example_text": "!!!!!NOT IMPLEMENTED!!!!!",
    },
    "image_classification": {
        "model_class": None,  # Not implemented
        "handler_class": ImageClassificationHandler,
        "example_text": "!!!!!NOT IMPLEMENTED!!!!!",
    },
}
def get_model_handler(task: str, model_name: str, quantization_type: str, test_text: str):
    """Instantiate the handler registered for ``task`` in ``TASK_CONFIGS``.

    Args:
        task: Key into ``TASK_CONFIGS`` (e.g. ``"ner"`` or ``"causal_lm"``).
        model_name: Model identifier, forwarded unchanged to the handler.
        quantization_type: Quantization setting, forwarded unchanged to the handler.
        test_text: Input text, forwarded unchanged to the handler.

    Returns:
        ``handler_class(model_name, model_class, quantization_type, test_text)``
        where both classes come from the task's config entry. If the entry's
        ``model_class`` is ``None``, ``None`` is passed through to the handler.

    Raises:
        ValueError: If ``task`` has no entry in ``TASK_CONFIGS``. The message
            lists the supported task names.
    """
    task_config = TASK_CONFIGS.get(task)
    # Only a missing key means "unknown task"; don't rely on dict truthiness.
    if task_config is None:
        supported = ", ".join(TASK_CONFIGS)
        raise ValueError(
            f"No configuration found for task: {task}. Supported tasks: {supported}"
        )
    handler_class = task_config["handler_class"]
    model_class = task_config["model_class"]
    return handler_class(model_name, model_class, quantization_type, test_text)