| |
| import unittest |
| from typing import Any, Dict, Optional |
|
|
| import torch |
|
|
| from swift.llm import (DatasetMeta, InferRequest, Model, ModelGroup, ModelMeta, PtEngine, RequestConfig, |
| ResponsePreprocessor, TemplateMeta, get_model_tokenizer_with_flash_attn, load_dataset, |
| register_dataset, register_model, register_template) |
|
|
|
|
class CustomPreprocessor(ResponsePreprocessor):
    """Convert a row with ``text1``, ``text2`` and ``label`` into a query/response pair.

    The query asks for a similarity score of the two sentences; the response
    is the row's label rendered with one decimal place.
    """

    # Query template; ``{text1}`` and ``{text2}`` are filled in per row.
    prompt = ('Task: Based on the given two sentences, provide a similarity score between 0.0 and 5.0.\n'
              'Sentence 1: {text1}\n'
              'Sentence 2: {text2}\n'
              'Similarity score: ')

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Build the ``query``/``response`` mapping for ``row`` and delegate to the parent.

        Args:
            row: Mapping providing ``text1``, ``text2`` and a ``label`` that
                supports the ``.1f`` format specifier.

        Returns:
            The result of ``ResponsePreprocessor.preprocess`` for the built mapping.
        """
        query = self.prompt.format(text1=row['text1'], text2=row['text2'])
        response = f"{row['label']:.1f}"
        return super().preprocess({'query': query, 'response': response})
|
|
|
|
# Register the STS-B dataset under its two dataset ids, with a
# CustomPreprocessor instance as the preprocessing function.
register_dataset(
    DatasetMeta(
        ms_dataset_id='swift/stsb',
        hf_dataset_id='SetFit/stsb',
        preprocess_func=CustomPreprocessor(),
    ),
)
|
|
# Register the chat template named 'custom': a system prefix, a user/assistant
# prompt section and a newline as the separator between chat turns.
register_template(
    TemplateMeta(
        template_type='custom',
        prefix=['<extra_id_0>System\n{{SYSTEM}}\n'],
        prompt=['<extra_id_1>User\n{{QUERY}}\n<extra_id_1>Assistant\n'],
        chat_sep=['\n'],
    ),
)
|
|
# Register the model type 'custom' for Nemotron-Mini-4B-Instruct and bind it
# to the 'custom' template.
register_model(
    ModelMeta(
        model_type='custom',
        model_groups=[
            ModelGroup([
                Model('AI-ModelScope/Nemotron-Mini-4B-Instruct', 'nvidia/Nemotron-Mini-4B-Instruct'),
            ]),
        ],
        template='custom',
        get_function=get_model_tokenizer_with_flash_attn,
        ignore_patterns=['nemo'],
    ),
)
|
|
|
|
class TestCustom(unittest.TestCase):
    """Checks for the custom model and the custom dataset used by this module.

    ``self.assertEqual`` is used instead of bare ``assert`` so the checks are
    not stripped when Python runs with ``-O`` and failures show both values.
    """

    def test_custom_model(self):
        """Infer once with the default template backend and once with 'jinja'.

        The same request is sent twice through one ``PtEngine``; the two
        generated message contents must be equal.
        """
        infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
        # NOTE(review): temperature=0 is presumably chosen so both runs are
        # comparable for exact equality -- confirm.
        request_config = RequestConfig(max_tokens=512, temperature=0)
        engine = PtEngine('AI-ModelScope/Nemotron-Mini-4B-Instruct', torch.float16)
        response = engine.infer([infer_request], request_config)
        swift_response = response[0].choices[0].message.content

        # Switch the backend on the engine's template and repeat the request.
        engine.default_template.template_backend = 'jinja'
        response = engine.infer([infer_request], request_config)
        jinja_response = response[0].choices[0].message.content
        self.assertEqual(
            swift_response,
            jinja_response,
            f'swift_response: {swift_response}\njinja_response: {jinja_response}',
        )
        print(f'response: {swift_response}')

    def test_custom_dataset(self):
        """Load 'swift/stsb' and check its size and the keys of the first row."""
        # Only the first element of what load_dataset returns is inspected.
        dataset = load_dataset(['swift/stsb'])[0]
        self.assertEqual(len(dataset), 5749)
        self.assertEqual(list(dataset[0].keys()), ['messages'])
        print(f'dataset: {dataset}')
        print(f'dataset[0]: {dataset[0]}')
|
|
|
|
# Run this module's test cases when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|
|