File size: 5,478 Bytes
cb2428f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import os
from copy import deepcopy
from typing import List, Optional
from openai import OpenAI
from swift.llm.infer.protocol import InferRequest, RequestConfig
from swift.llm.sampling.vanilla_sampler import VanillaSampler
from .utils import get_messages_md5
class OpenAI_Engine():
    """Thin synchronous wrapper around an OpenAI-compatible chat endpoint.

    Each request is sent individually through ``client.chat.completions.create``
    and the reply is flattened to a single string. When the reply carries a
    non-empty ``reasoning_content`` field, the string has the form
    ``<think>{reasoning}</think>\\n\\n<answer>{content}</answer>``; otherwise it
    is just the answer content.
    """

    def __init__(
        self,
        model: str,
        stream: bool = False,
        base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1',
        api_key: str = '',
        **kwargs,
    ):
        """Create the underlying ``OpenAI`` client.

        Args:
            model: Model name forwarded with every completion request.
            stream: Whether completions are requested in streaming mode.
            base_url: Endpoint base URL handed to the ``OpenAI`` client.
            api_key: API key; when empty, the ``OPENAI_API_KEY`` environment
                variable is used instead.
            **kwargs: Extra keyword arguments forwarded to ``OpenAI(...)``.
        """
        self.model = model
        self.stream = stream
        self.client = OpenAI(api_key=api_key if api_key else os.getenv('OPENAI_API_KEY'), base_url=base_url, **kwargs)

    @staticmethod
    def _read_stream(completion):
        """Join the chunks of a streamed completion into (reasoning, answer)."""
        reasoning_parts = []
        answer_parts = []
        for chunk in completion:
            # Chunks without any choice carry no text; skip them.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            reasoning_chunk = getattr(delta, 'reasoning_content', None)
            answer_chunk = delta.content
            if reasoning_chunk:
                reasoning_parts.append(reasoning_chunk)
            # Checked independently of the reasoning text so that a chunk
            # carrying both pieces does not lose its answer part.
            if answer_chunk:
                answer_parts.append(answer_chunk)
        return ''.join(reasoning_parts), ''.join(answer_parts)

    @staticmethod
    def _read_message(completion):
        """Extract (reasoning, answer) from a non-streamed completion."""
        message = completion.choices[0].message
        # Always produce a value: the attribute may be missing or None, and the
        # caller must never see a leftover value from a previous request.
        reasoning_content = getattr(message, 'reasoning_content', None) or ''
        content = message.content or ''
        return reasoning_content, content

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
    ) -> List[str]:
        """Run every request against the endpoint and return the reply texts.

        Args:
            infer_requests: Requests to send; each is indexed with
                ``['messages']`` to obtain the chat messages.
            request_config: Source of ``temperature``, ``top_p`` and
                ``max_tokens``. When ``None``, these parameters are not sent
                at all.

        Returns:
            One string per request, in request order.

        Raises:
            AssertionError: If a completion comes back with empty content.
        """
        sampling_kwargs = {}
        if request_config is not None:
            sampling_kwargs = {
                'temperature': request_config.temperature,
                'top_p': request_config.top_p,
                'max_tokens': request_config.max_tokens,
            }

        resp_contents = []
        for infer_request in infer_requests:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=infer_request['messages'],
                stream=self.stream,
                **sampling_kwargs,
            )
            if self.stream:
                reasoning_content, content = self._read_stream(completion)
            else:
                reasoning_content, content = self._read_message(completion)

            # Raised explicitly (same exception type as the former assert) so
            # the check also runs when Python is started with -O.
            if not content:
                raise AssertionError('Empty completion')

            if reasoning_content:
                resp_content = f'<think>{reasoning_content}</think>\n\n<answer>{content}</answer>'
            else:
                resp_content = content
            resp_contents.append(resp_content)
        return resp_contents
class DistillSampler(VanillaSampler):
    """Sampler that obtains its generations from a remote ``OpenAI_Engine``.

    Rows whose message hash is already in ``self.caches`` with the expected
    number of choices are served from the cache; all other rows are sent to the
    engine ``num_return_sequences`` times each.
    """

    def __init__(self, *args, **kwargs):
        # super(VanillaSampler, self) starts the MRO lookup *after*
        # VanillaSampler, so VanillaSampler.__init__ itself is skipped.
        # NOTE(review): presumably to avoid its own engine setup - confirm.
        super(VanillaSampler, self).__init__(*args, **kwargs)
        assert self.args.sampler_engine == 'client'
        _Engine = OpenAI_Engine
        self.infer_engine = _Engine(model=self.args.model, stream=self.args.stream, **self.args.engine_kwargs)
        self.infer_engine.strict = False
        self.caches = self.read_cache()

    def _prepare_model_tokenizer(self):
        # Intentionally a no-op in this sampler.
        pass

    def _prepare_template(self):
        # Intentionally a no-op in this sampler.
        pass

    def extract_choice(self, resp):
        """Flatten the first choice of a completion response into one string.

        Returns ``<think>...</think>\\n\\n<answer>...</answer>`` when the
        message has a non-empty ``reasoning_content``; otherwise the plain
        message content.
        """
        message = resp.choices[0].message
        # Test the value, not just the attribute's presence, so a None or
        # empty reasoning field does not produce '<think>None</think>'.
        reasoning_content = getattr(message, 'reasoning_content', None)
        if reasoning_content:
            return f'<think>{reasoning_content}</think>\n\n<answer>{message.content}</answer>'
        return message.content

    def _complete_cached_choices(self, row):
        """Return the cached choices for ``row`` if the cache holds exactly
        ``num_return_sequences`` of them, else ``None``."""
        uuid = get_messages_md5(row)
        if uuid in self.caches:
            choices = self.caches[uuid]['choices']
            if len(choices) == self.args.num_return_sequences:
                return choices
        return None

    def _to_infer_request(self, row):
        """Turn a private copy of a data row into the request sent to the engine.

        Applies ``self.args.system`` (overwriting an existing leading system
        message or inserting a new one) and drops a trailing assistant message.
        """
        messages = row['messages']
        if self.args.system:
            if messages and messages[0]['role'] == 'system':
                messages[0]['content'] = self.args.system
            else:
                messages.insert(0, {'role': 'system', 'content': self.args.system})
        if messages and messages[-1]['role'] == 'assistant':
            messages = messages[:-1]
        row['messages'] = messages
        return row

    def generate(self, data):
        """Produce ``choices`` for every row derived from ``data``.

        Args:
            data: Input passed to ``self.convert_data_to_rows``.

        Returns:
            A list of row copies in input order, each with a ``'choices'`` list
            that is either the cached list or the freshly generated texts.
        """
        num_seqs = self.args.num_return_sequences
        resp_all = []
        pending = []  # output rows still waiting for fresh generations
        infer_requests = []
        for row in self.convert_data_to_rows(data):
            out_row = deepcopy(row)
            cached = self._complete_cached_choices(out_row)
            if cached is not None:
                out_row['choices'] = cached
                resp_all.append(out_row)
                continue
            # The request is built from a separate copy so the returned row
            # keeps its original, unmodified messages.
            request = self._to_infer_request(deepcopy(row))
            infer_requests.extend(deepcopy(request) for _ in range(num_seqs))
            pending.append(out_row)
            resp_all.append(out_row)

        request_config = RequestConfig(
            max_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
        )
        resp_list = []
        if infer_requests:
            resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)

        # Responses come back in request order: num_seqs consecutive entries
        # belong to each pending row.
        for position, out_row in enumerate(pending):
            offset = position * num_seqs
            out_row['choices'] = [resp_list[offset + j] for j in range(num_seqs)]
        return resp_all
|