# interactSpeech/swift/llm/sampling/distill_sampler.py
import os
from copy import deepcopy
from typing import List, Optional
from openai import OpenAI
from swift.llm.infer.protocol import InferRequest, RequestConfig
from swift.llm.sampling.vanilla_sampler import VanillaSampler
from .utils import get_messages_md5
class OpenAI_Engine:
    """Minimal OpenAI-compatible chat-completions client used for distillation sampling.

    Wraps the ``openai`` SDK's chat-completions endpoint (defaults to the
    DashScope compatible-mode URL) and returns plain response strings.
    """

    def __init__(
        self,
        model: str,
        stream: bool = False,
        base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1',
        api_key: str = '',
        **kwargs,
    ):
        """
        Args:
            model: Model name passed to the completions endpoint.
            stream: Whether to request streaming responses.
            base_url: OpenAI-compatible API base URL.
            api_key: API key; falls back to the ``OPENAI_API_KEY`` env var.
            **kwargs: Extra arguments forwarded to the ``OpenAI`` client.
        """
        self.model = model
        self.stream = stream
        # Fall back to the environment variable when no explicit key is given.
        self.client = OpenAI(api_key=api_key if api_key else os.getenv('OPENAI_API_KEY'), base_url=base_url, **kwargs)

    def infer(
        self,
        infer_requests: 'List[InferRequest]',
        request_config: 'Optional[RequestConfig]' = None,
    ):
        """Run one chat completion per request and return the response texts.

        When the model emits ``reasoning_content`` (thinking models), the
        result is wrapped as ``<think>...</think>\\n\\n<answer>...</answer>``;
        otherwise the raw content is returned.

        Args:
            infer_requests: Requests; each must provide a ``messages`` list.
            request_config: Sampling parameters (temperature, top_p, max_tokens).

        Returns:
            A list with one response string per input request.

        Raises:
            ValueError: If the endpoint returns an empty completion.
        """
        resp_contents = []
        for infer_request in infer_requests:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=infer_request['messages'],
                temperature=request_config.temperature,
                top_p=request_config.top_p,
                max_tokens=request_config.max_tokens,
                stream=self.stream,
            )
            reasoning_content = ''
            content = ''
            if self.stream:
                for chunk in completion:
                    chunk_choices = chunk.choices
                    if len(chunk_choices) == 0:
                        continue
                    delta = chunk_choices[0].delta
                    # ``reasoning_content`` is only present on thinking models.
                    reasoning_chunk = getattr(delta, 'reasoning_content', '') or ''
                    answer_chunk = delta.content
                    if reasoning_chunk:
                        reasoning_content += reasoning_chunk
                    elif answer_chunk:
                        content += answer_chunk
            else:
                message = completion.choices[0].message
                # BUG FIX: previously ``reasoning_content`` was only assigned when
                # the attribute existed, raising NameError for plain (non-thinking)
                # models at the ``if reasoning_content:`` check below.
                reasoning_content = getattr(message, 'reasoning_content', '') or ''
                # ``message.content`` may be None; normalize so the emptiness
                # check below cannot raise TypeError.
                content = message.content or ''
            # Raise instead of assert: assertions are stripped under ``python -O``.
            if not content:
                raise ValueError('Empty completion')
            if reasoning_content:
                resp_content = f'<think>{reasoning_content}</think>\n\n<answer>{content}</answer>'
            else:
                resp_content = content
            resp_contents.append(resp_content)
        return resp_contents
class DistillSampler(VanillaSampler):
    """Sampler that distills responses from a remote OpenAI-compatible endpoint.

    Instead of preparing a local model, each request is sent through an
    ``OpenAI_Engine`` client; responses are cached by message hash so rows
    that already have the full set of choices are not re-sampled.
    """

    def __init__(self, *args, **kwargs):
        # NOTE: deliberately skips VanillaSampler.__init__ (calls its parent's
        # __init__ instead) so no local engine/model is set up for client mode.
        super(VanillaSampler, self).__init__(*args, **kwargs)
        assert self.args.sampler_engine == 'client'
        _Engine = OpenAI_Engine
        self.infer_engine = _Engine(model=self.args.model, stream=self.args.stream, **self.args.engine_kwargs)
        self.infer_engine.strict = False
        # Maps message-md5 uuid -> cached row containing a 'choices' list.
        self.caches = self.read_cache()

    def _prepare_model_tokenizer(self):
        # No local model is loaded in client mode.
        pass

    def _prepare_template(self):
        # No template is needed in client mode.
        pass

    def extract_choice(self, resp):
        """Format one response object into a plain string.

        Wraps reasoning output (when present) in <think>/<answer> tags;
        ``resp`` presumably is an OpenAI chat-completion response — confirm
        against the caller.
        """
        message = resp.choices[0].message
        if hasattr(message, 'reasoning_content'):
            reps_content = f'<think>{message.reasoning_content}</think>\n\n<answer>{message.content}</answer>'
        else:
            reps_content = message.content
        return reps_content

    def generate(self, data):
        """Sample ``num_return_sequences`` responses per row, reusing cached ones.

        Args:
            data: Dataset accepted by ``convert_data_to_rows``.

        Returns:
            List of rows, each augmented with a ``choices`` list of responses.
        """
        resp_all = []
        infer_requests = []
        sent = 0
        rows = self.convert_data_to_rows(data)
        # Pass 1: build the request batch for rows that are not fully cached.
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            messages = row['messages']
            # Hash is computed on the unmodified row so it matches pass 2 below.
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    # Fully cached: nothing to sample for this row.
                    continue
            if self.args.system:
                # Override an existing system prompt or inject a new one.
                if messages[0]['role'] == 'system':
                    messages[0]['content'] = self.args.system
                else:
                    messages.insert(0, {'role': 'system', 'content': self.args.system})
            if messages[-1]['role'] == 'assistant':
                # Drop a trailing assistant turn so the model generates it anew.
                messages = messages[:-1]
            row['messages'] = messages
            infer_request = row
            # One request copy per desired return sequence.
            for i in range(self.args.num_return_sequences):
                infer_requests.append(deepcopy(infer_request))
            sent += 1
        request_config = RequestConfig(
            max_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
        )
        resp_list = []
        if len(infer_requests) > 0:
            resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)
        # Pass 2: stitch responses back onto rows; iteration order and the
        # skip condition mirror pass 1, so _cur indexes each row's slice.
        _cur = 0
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    row['choices'] = choices
                    resp_all.append(row)
                    continue
            resps = row
            resps['choices'] = []
            # Consume this row's contiguous slice of num_return_sequences responses.
            for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
                resps['choices'].append(resp_list[j])
            resp_all.append(resps)
            _cur += 1
        return resp_all