# Source: interactSpeech/swift/llm/sampling/distill_sampler.py
# Uploaded by Student0809 using the upload-large-folder tool (commit cb2428f, verified; 5.48 kB).
import os
from copy import deepcopy
from typing import List, Optional
from openai import OpenAI
from swift.llm.infer.protocol import InferRequest, RequestConfig
from swift.llm.sampling.vanilla_sampler import VanillaSampler
from .utils import get_messages_md5
class OpenAI_Engine():
    """Thin inference engine that talks to an OpenAI-compatible chat endpoint.

    Requests are issued sequentially through ``openai.OpenAI``. When the
    endpoint returns a ``reasoning_content`` field alongside the answer, the
    two are merged into a single ``<think>…</think>\\n\\n<answer>…</answer>``
    string; otherwise the plain answer text is returned.
    """

    def __init__(
        self,
        model: str,
        stream: bool = False,
        base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1',
        api_key: str = '',
        **kwargs,
    ):
        """Create the underlying client.

        Args:
            model: Model name sent with every completion request.
            stream: Whether completions are requested in streaming mode.
            base_url: Base URL of the OpenAI-compatible service.
            api_key: API key; when empty, the ``OPENAI_API_KEY`` environment
                variable is used instead.
            **kwargs: Extra keyword arguments forwarded to ``OpenAI(...)``.
        """
        self.model = model
        self.stream = stream
        self.client = OpenAI(api_key=api_key or os.getenv('OPENAI_API_KEY'), base_url=base_url, **kwargs)

    @staticmethod
    def _collect_stream(completion):
        """Concatenate the reasoning and answer text of a streamed completion.

        Returns:
            A ``(reasoning_content, content)`` tuple of strings (either may be
            empty).
        """
        reasoning_parts = []
        answer_parts = []
        for chunk in completion:
            # Some chunks (e.g. usage-only ones) carry no choices at all.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            reasoning_chunk = getattr(delta, 'reasoning_content', None)
            answer_chunk = getattr(delta, 'content', None)
            # Checked independently so that answer text is not dropped when a
            # single delta happens to carry both fields.
            if reasoning_chunk:
                reasoning_parts.append(reasoning_chunk)
            if answer_chunk:
                answer_parts.append(answer_chunk)
        return ''.join(reasoning_parts), ''.join(answer_parts)

    def infer(
        self,
        infer_requests: List['InferRequest'],
        request_config: Optional['RequestConfig'] = None,
    ):
        """Run every request against the endpoint and return the response texts.

        Args:
            infer_requests: Requests to send; each one is indexed with
                ``['messages']`` to obtain the chat history.
            request_config: Source of ``temperature``, ``top_p`` and
                ``max_tokens``. Settings that are ``None`` (or the whole
                config being ``None``) are omitted from the API call so the
                service defaults apply.

        Returns:
            A list with one string per request, in request order.

        Raises:
            AssertionError: If a completion comes back with empty content.
        """
        # The sampling settings are identical for every request, so build them once.
        sampling_kwargs = {}
        for name in ('temperature', 'top_p', 'max_tokens'):
            value = getattr(request_config, name, None)
            if value is not None:
                sampling_kwargs[name] = value

        resp_contents = []
        for infer_request in infer_requests:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=infer_request['messages'],
                stream=self.stream,
                **sampling_kwargs,
            )
            if self.stream:
                reasoning_content, content = self._collect_stream(completion)
            else:
                message = completion.choices[0].message
                # Resolved per request: the attribute is absent for models
                # that do not emit reasoning, and must never leak over from a
                # previous iteration.
                reasoning_content = getattr(message, 'reasoning_content', None)
                content = message.content
            # Explicit raise (rather than ``assert``) so the check survives
            # ``python -O`` and also covers ``content is None``; the exception
            # type is kept for existing callers.
            if not content:
                raise AssertionError('Empty completion')
            if reasoning_content:
                resp_content = f'<think>{reasoning_content}</think>\n\n<answer>{content}</answer>'
            else:
                resp_content = content
            resp_contents.append(resp_content)
        return resp_contents
class DistillSampler(VanillaSampler):
    """Sampler that collects responses from a remote OpenAI-compatible model.

    Rows whose complete set of choices is already present in ``self.caches``
    are served from the cache; all other rows are sent to ``OpenAI_Engine``
    ``num_return_sequences`` times each.
    """

    def __init__(self, *args, **kwargs):
        # ``super(VanillaSampler, self)`` starts the MRO lookup *after*
        # VanillaSampler, so VanillaSampler.__init__ itself is skipped.
        # NOTE(review): presumably because it would set up a local inference
        # engine that this client-only sampler does not need — confirm.
        super(VanillaSampler, self).__init__(*args, **kwargs)
        assert self.args.sampler_engine == 'client', "DistillSampler requires sampler_engine == 'client'"
        self.infer_engine = OpenAI_Engine(
            model=self.args.model, stream=self.args.stream, **self.args.engine_kwargs)
        self.infer_engine.strict = False
        self.caches = self.read_cache()

    def _prepare_model_tokenizer(self):
        """Do nothing; no model or tokenizer is prepared by this sampler."""

    def _prepare_template(self):
        """Do nothing; no template is prepared by this sampler."""

    def extract_choice(self, resp):
        """Return the text of the first choice of a non-streamed completion.

        When the message carries non-empty ``reasoning_content`` the result is
        ``<think>…</think>\\n\\n<answer>…</answer>``; otherwise it is the plain
        message content. An attribute that exists but is ``None``/empty is
        treated as "no reasoning", matching ``OpenAI_Engine.infer``.
        """
        message = resp.choices[0].message
        reasoning_content = getattr(message, 'reasoning_content', None)
        if reasoning_content:
            return f'<think>{reasoning_content}</think>\n\n<answer>{message.content}</answer>'
        return message.content

    def _cached_choices(self, row):
        """Return the cached choices for ``row`` when the cache holds a full set, else ``None``."""
        uuid = get_messages_md5(row)
        if uuid not in self.caches:
            return None
        choices = self.caches[uuid]['choices']
        if len(choices) == self.args.num_return_sequences:
            return choices
        # A partial set is ignored: the row is sampled again from scratch.
        return None

    def generate(self, data):
        """Produce ``num_return_sequences`` choices for every row in ``data``.

        Args:
            data: Raw batch, converted to rows with ``convert_data_to_rows``.

        Returns:
            A list of row copies, in input order, each with a ``'choices'``
            list taken either from the cache or from the inference engine.
        """
        num_sequences = self.args.num_return_sequences
        rows = self.convert_data_to_rows(data)

        # Pass 1: build the requests for every row that is not fully cached.
        infer_requests = []
        for row in rows:
            row = deepcopy(row)
            if self._cached_choices(row) is not None:
                continue
            messages = row['messages']
            if self.args.system:
                # Override an existing system prompt, or prepend one.
                if messages[0]['role'] == 'system':
                    messages[0]['content'] = self.args.system
                else:
                    messages.insert(0, {'role': 'system', 'content': self.args.system})
            # A trailing assistant turn is removed so the remote model
            # generates that turn itself.
            if messages[-1]['role'] == 'assistant':
                messages = messages[:-1]
            row['messages'] = messages
            infer_requests.extend(deepcopy(row) for _ in range(num_sequences))

        # NOTE: top_k is set on the config, but OpenAI_Engine.infer only
        # forwards temperature, top_p and max_tokens.
        request_config = RequestConfig(
            max_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
        )

        resp_list = []
        if infer_requests:
            resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)

        # Pass 2: stitch cached and freshly generated choices back together.
        # Responses come back in request order, so the n-th uncached row owns
        # the n-th consecutive group of ``num_sequences`` responses.
        resp_all = []
        uncached_seen = 0
        for row in rows:
            row = deepcopy(row)
            choices = self._cached_choices(row)
            if choices is None:
                start = num_sequences * uncached_seen
                choices = [resp_list[j] for j in range(start, start + num_sequences)]
                uncached_seen += 1
            row['choices'] = choices
            resp_all.append(row)
        return resp_all