import os
from copy import deepcopy
from typing import List, Optional

from openai import OpenAI

from swift.llm.infer.protocol import InferRequest, RequestConfig
from swift.llm.sampling.vanilla_sampler import VanillaSampler
from .utils import get_messages_md5


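# OpenAI_Engine wraps an OpenAI-compatible chat endpoint (DashScope by default)
# so that sampling can run against a remote model instead of a locally loaded one.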
class OpenAI_Engine:

    def __init__(
        self,
        model: str,
        stream: bool = False,
        base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1',
        api_key: str = '',
        **kwargs,
    ):
        self.model = model
        self.stream = stream
        # Fall back to the OPENAI_API_KEY environment variable when no key is given.
        self.client = OpenAI(api_key=api_key or os.getenv('OPENAI_API_KEY'), base_url=base_url, **kwargs)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
    ):
        # request_config is dereferenced below, so substitute defaults when omitted.
        if request_config is None:
            request_config = RequestConfig()
        resp_contents = []
        for infer_request in infer_requests:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=infer_request['messages'],
                temperature=request_config.temperature,
                top_p=request_config.top_p,
                max_tokens=request_config.max_tokens,
                stream=self.stream,
            )
            if self.stream:
                # Accumulate the reasoning and answer deltas chunk by chunk.
                reasoning_content = ''
                content = ''
                for chunk in completion:
                    chunk_choices = chunk.choices
                    if len(chunk_choices) == 0:
                        continue
                    delta = chunk_choices[0].delta
                    reasoning_chunk = getattr(delta, 'reasoning_content', '')
                    answer_chunk = delta.content
                    if reasoning_chunk:
                        reasoning_content += reasoning_chunk
                    elif answer_chunk:
                        content += answer_chunk
            else:
                message = completion.choices[0].message
                # reasoning_content is only present on models that expose thinking
                # traces; default to '' so the check below never hits an unbound name.
                reasoning_content = getattr(message, 'reasoning_content', '') or ''
                content = message.content
            assert content, 'Empty completion'
            if reasoning_content:
                resp_content = f'<think>{reasoning_content}</think>\n\n<answer>{content}</answer>'
            else:
                resp_content = content
            resp_contents.append(resp_content)

        return resp_contents
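

# A minimal usage sketch (hypothetical values; assumes OPENAI_API_KEY is set in
# the environment and that the endpoint serves the named model):
#
#   engine = OpenAI_Engine(model='qwen-max')
#   config = RequestConfig(max_tokens=512, temperature=0.7, top_p=0.9)
#   answers = engine.infer(
#       [{'messages': [{'role': 'user', 'content': 'What is 2 + 2?'}]}],
#       request_config=config,
#   )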


class DistillSampler(VanillaSampler):

    def __init__(self, *args, **kwargs):
        # Bypass VanillaSampler.__init__ and run the base Sampler initializer;
        # the local model/template hooks are stubbed out as no-ops below.
        super(VanillaSampler, self).__init__(*args, **kwargs)
        assert self.args.sampler_engine == 'client'
        self.infer_engine = OpenAI_Engine(model=self.args.model, stream=self.args.stream, **self.args.engine_kwargs)
        self.infer_engine.strict = False
        self.caches = self.read_cache()

    def _prepare_model_tokenizer(self):
        # No local model or tokenizer is needed; requests go through the client.
        pass

    def _prepare_template(self):
        pass

    def extract_choice(self, resp):
        message = resp.choices[0].message
        # Mirror the <think>/<answer> formatting used by OpenAI_Engine.infer.
        if hasattr(message, 'reasoning_content'):
            resp_content = f'<think>{message.reasoning_content}</think>\n\n<answer>{message.content}</answer>'
        else:
            resp_content = message.content
        return resp_content

    def generate(self, data):
        resp_all = []
        infer_requests = []
        sent = 0
        rows = self.convert_data_to_rows(data)
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            messages = row['messages']
            uuid = get_messages_md5(row)
            # Skip rows whose choices are already fully cached.
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    continue
            # Override or inject the system prompt when one is configured.
            if self.args.system:
                if messages[0]['role'] == 'system':
                    messages[0]['content'] = self.args.system
                else:
                    messages.insert(0, {'role': 'system', 'content': self.args.system})
            # Drop a trailing assistant turn so the model regenerates it.
            if messages[-1]['role'] == 'assistant':
                messages = messages[:-1]

            row['messages'] = messages
            # Queue one copy of the request per requested return sequence.
            for _ in range(self.args.num_return_sequences):
                infer_requests.append(deepcopy(row))
            sent += 1

        request_config = RequestConfig(
            max_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
        )

        resp_list = []
        if len(infer_requests) > 0:
            resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)

        # Stitch cached and freshly sampled choices back together in row order; each
        # un-cached row consumed num_return_sequences consecutive slots in resp_list.
        _cur = 0
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    row['choices'] = choices
                    resp_all.append(row)
                    continue

            resps = row
            resps['choices'] = []
            for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
                resps['choices'].append(resp_list[j])
            resp_all.append(resps)
            _cur += 1
        return resp_all
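
# Rough sketch of driving this sampler from the command line (the flags mirror
# the self.args fields used above; exact invocation may differ, so consult the
# swift sampling docs):
#
#   swift sample \
#       --sampler_engine client \
#       --model <served-model-name> \
#       --num_return_sequences 2 \
#       --stream false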