| import os | |
| from copy import deepcopy | |
| from typing import List, Optional | |
| from openai import OpenAI | |
| from swift.llm.infer.protocol import InferRequest, RequestConfig | |
| from swift.llm.sampling.vanilla_sampler import VanillaSampler | |
| from .utils import get_messages_md5 | |
class OpenAI_Engine():
    """Thin synchronous client for an OpenAI-compatible chat-completions endpoint.

    Used as a remote "sampler engine": each `InferRequest` is sent to the
    service one at a time and the text of the reply is returned. When the
    model exposes a `reasoning_content` field (e.g. DashScope reasoning
    models), the reply is wrapped as ``<think>...</think>\n\n<answer>...</answer>``.
    """

    def __init__(
        self,
        model: str,
        stream: bool = False,
        base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1',
        api_key: str = '',
        **kwargs,
    ):
        """Create the underlying OpenAI client.

        Args:
            model: Model name passed to the completions endpoint.
            stream: If True, responses are consumed chunk-by-chunk.
            base_url: OpenAI-compatible endpoint (defaults to DashScope).
            api_key: Explicit key; falls back to the OPENAI_API_KEY env var.
            **kwargs: Forwarded verbatim to the `OpenAI` client constructor.
        """
        self.model = model
        self.stream = stream
        # Fall back to the environment variable when no key is given explicitly.
        self.client = OpenAI(api_key=api_key if api_key else os.getenv('OPENAI_API_KEY'), base_url=base_url, **kwargs)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
    ):
        """Run each request sequentially and return the list of reply texts.

        Args:
            infer_requests: Requests; each must provide a ``'messages'`` entry.
            request_config: Sampling parameters; a default config is used when None.

        Returns:
            List[str]: one response string per request, in input order.

        Raises:
            AssertionError: if the service returns an empty completion.
        """
        # Fix: request_config is Optional, but its attributes were read
        # unconditionally below — default to an empty config instead of crashing.
        if request_config is None:
            request_config = RequestConfig()
        resp_contents = []
        for infer_request in infer_requests:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=infer_request['messages'],
                temperature=request_config.temperature,
                top_p=request_config.top_p,
                max_tokens=request_config.max_tokens,
                stream=self.stream,
            )
            # Fix: previously `reasoning_content` was only bound inside a
            # hasattr() branch in the non-stream path, raising NameError for
            # models without that attribute. Initialize both accumulators here.
            reasoning_content = ''
            content = ''
            if self.stream:
                for chunk in completion:
                    chunk_choices = chunk.choices
                    if len(chunk_choices) == 0:
                        # Keep-alive / usage-only chunks carry no choices.
                        continue
                    delta = chunk_choices[0].delta
                    # Reasoning models interleave reasoning and answer deltas;
                    # either field may be absent or None on a given chunk.
                    reasoning_chunk = getattr(delta, 'reasoning_content', '')
                    answer_chunk = delta.content
                    if reasoning_chunk:
                        reasoning_content += reasoning_chunk
                    elif answer_chunk:
                        content += answer_chunk
            else:
                message = completion.choices[0].message
                # The attribute may be missing (plain chat models) or None;
                # normalize both cases to '' so the wrapping check below works.
                reasoning_content = getattr(message, 'reasoning_content', '') or ''
                content = message.content
            assert len(content) > 0, 'Empty completion'
            if reasoning_content:
                resp_content = f'<think>{reasoning_content}</think>\n\n<answer>{content}</answer>'
            else:
                resp_content = content
            resp_contents.append(resp_content)
        return resp_contents
class DistillSampler(VanillaSampler):
    """Sampler that distills responses from a remote OpenAI-compatible service.

    Instead of loading a local model/template like `VanillaSampler`, it sends
    requests through an `OpenAI_Engine` client and caches completed rows
    (keyed by an MD5 of the messages) so already-sampled rows are reused.
    """

    def __init__(self, *args, **kwargs):
        # NOTE: deliberately calls super(VanillaSampler, self).__init__, which
        # SKIPS VanillaSampler.__init__ and runs its parent's __init__ instead —
        # this avoids loading a local model; the empty _prepare_* overrides
        # below exist for the same reason.
        super(VanillaSampler, self).__init__(*args, **kwargs)
        assert self.args.sampler_engine == 'client'
        _Engine = OpenAI_Engine
        self.infer_engine = _Engine(model=self.args.model, stream=self.args.stream, **self.args.engine_kwargs)
        self.infer_engine.strict = False
        # Cache of previously generated rows: {uuid: {'choices': [...]}}
        # (schema assumed from usage in generate() — confirm against read_cache).
        self.caches = self.read_cache()

    def _prepare_model_tokenizer(self):
        # Intentionally a no-op: inference is remote, no local model is needed.
        pass

    def _prepare_template(self):
        # Intentionally a no-op: the remote service applies its own template.
        pass

    def extract_choice(self, resp):
        # Convert one raw completion response into a single text, wrapping
        # reasoning models' output in <think>/<answer> tags when present.
        message = resp.choices[0].message
        if hasattr(message, 'reasoning_content'):
            reps_content = f'<think>{message.reasoning_content}</think>\n\n<answer>{message.content}</answer>'
        else:
            reps_content = message.content
        return reps_content

    def generate(self, data):
        """Generate `num_return_sequences` choices per row, using the cache.

        Two passes over the same rows, in the same order:
          1) build `infer_requests` for every row not fully cached;
          2) walk rows again, taking cached choices or slicing the flat
             `resp_list` with `_cur` — the slice arithmetic relies on both
             passes skipping exactly the same (fully cached) rows.
        """
        resp_all = []
        infer_requests = []
        sent = 0
        rows = self.convert_data_to_rows(data)
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            messages = row['messages']
            # uuid is computed BEFORE any system-prompt edits so it matches
            # the second pass, which hashes the unmodified row.
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    # Fully cached — no request needed for this row.
                    continue
            if self.args.system:
                # Override or insert the system prompt.
                if messages[0]['role'] == 'system':
                    messages[0]['content'] = self.args.system
                else:
                    messages.insert(0, {'role': 'system', 'content': self.args.system})
            if messages[-1]['role'] == 'assistant':
                # Drop a trailing assistant turn so the model generates it fresh.
                messages = messages[:-1]
            row['messages'] = messages
            infer_request = row
            # One request per desired sequence; requests stay grouped per row.
            for i in range(self.args.num_return_sequences):
                infer_requests.append(deepcopy(infer_request))
                sent += 1
        request_config = RequestConfig(
            max_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
        )
        resp_list = []
        if len(infer_requests) > 0:
            resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)
        # _cur counts non-cached rows seen so far; each owns the slice
        # [num_return_sequences*_cur, num_return_sequences*(_cur+1)) of resp_list.
        _cur = 0
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    row['choices'] = choices
                    resp_all.append(row)
                    continue
            resps = row
            resps['choices'] = []
            for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
                resps['choices'].append(resp_list[j])
            resp_all.append(resps)
            _cur += 1
        return resp_all