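"""Vanilla sampler: draw candidate responses with a pluggable inference engine.

The sampler generates ``num_return_sequences`` candidates per query using the
engine selected by ``sampler_engine`` (pt / vllm / lmdeploy), reuses previously
sampled generations from cache files, and optionally filters candidates with
outcome (ORM) and process (PRM) reward models before emitting JSONL training
lines.
"""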
import json
import os
from copy import deepcopy
from typing import Any, Dict, List

import numpy as np

from swift.llm import RequestConfig
from swift.llm.sampling.base import Sampler
from swift.llm.template.template_inputs import InferRequest
from swift.utils import get_logger

from .utils import get_messages_md5, get_reward

logger = get_logger()


class VanillaSampler(Sampler):
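    """Sampler that draws multiple candidate responses per query.

    Generation is delegated to the engine chosen by ``args.sampler_engine``;
    responses already present in ``args.cache_files`` are reused instead of
    being regenerated.
    """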

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Import only the requested backend, so the other engines do not need
        # to be installed.
        if self.args.sampler_engine == 'pt':
            from swift.llm import PtEngine
            _Engine = PtEngine
        elif self.args.sampler_engine == 'vllm':
            from swift.llm import VllmEngine
            _Engine = VllmEngine
        elif self.args.sampler_engine == 'lmdeploy':
            from swift.llm import LmdeployEngine
            _Engine = LmdeployEngine
        elif self.args.sampler_engine == 'no':
            # No engine: every query must be answered from the cache files.
            _Engine = None
        else:
            raise ValueError(f'Unsupported sampler_engine: {self.args.sampler_engine}')
        self.infer_engine = None
        if _Engine:
            self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)
            self.infer_engine.default_template = self.template
            self.infer_engine.strict = False
        self.caches = self.read_cache()

    def read_cache(self):
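        """Load previously sampled responses from ``self.args.cache_files``.

        Each cache line is a JSON object holding an ``id`` (produced by
        ``get_messages_md5``) and ``messages``; the final assistant message of
        every record is collected as one cached choice for that id.
        """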
        cache_files = self.args.cache_files
        caches = {}
        for file in cache_files:
            if not os.path.exists(file):
                logger.warning(f'Cache file does not exist: {file}')
                continue
            with open(file, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    content = json.loads(line)
                    uuid = content['id']
                    messages = content['messages']
                    if uuid not in caches:
                        caches[uuid] = {'choices': []}
                    # Each cached record must end with the sampled assistant reply.
                    assert messages[-1]['role'] == 'assistant'
                    caches[uuid]['choices'].append(messages[-1]['content'])
        return caches

    @staticmethod
    def convert_data_to_rows(data):
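        """Transpose a column-oriented batch into a list of row dicts.

        For example, ``{'a': [1, 2], 'b': [3, 4]}`` becomes
        ``[{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]``; image dicts carrying raw
        bytes are reduced to their file paths.
        """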
        rows = []
        first_key = next(iter(data))
        data_len = len(data[first_key])
        for idx in range(data_len):
            row = {key: data[key][idx] for key in data}
            # Datasets may store images as {'bytes': ..., 'path': ...}; keep only the path.
            if row.get('images') and 'bytes' in row['images'][0]:
                row['images'] = [img['path'] for img in row['images']]
            rows.append(row)
        VanillaSampler.check_row_valid(rows)
        return rows

    @staticmethod
    def check_row_valid(rows):
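        """Assert that every multimodal field holds non-empty path strings."""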
        for row in rows:
            assert not row.get('images') or all(isinstance(img, str) and img for img in row['images'])
            assert not row.get('videos') or all(isinstance(video, str) and video for video in row['videos'])
            assert not row.get('audios') or all(isinstance(audio, str) and audio for audio in row['audios'])

    def generate(self, data):
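        """Sample ``args.num_return_sequences`` responses for every row of ``data``.

        Rows whose cache already contains a full set of choices are returned
        as-is; all remaining rows are expanded into one inference request per
        sequence and sent to the engine in a single batch.
        """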
        resp_all = []
        infer_requests = []
        rows = self.convert_data_to_rows(data)
        for row in rows:
            row = deepcopy(row)
            messages = row['messages']
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                # Fully cached rows need no new requests.
                if len(choices) == self.args.num_return_sequences:
                    continue
            if self.args.system:
                # Override an existing system prompt, or prepend one.
                if messages[0]['role'] == 'system':
                    messages[0]['content'] = self.args.system
                else:
                    messages.insert(0, {'role': 'system', 'content': self.args.system})
            if messages[-1]['role'] == 'assistant':
                # Drop the ground-truth answer before sampling new responses.
                messages = messages[:-1]

            row['messages'] = messages
            # One request per returned sequence.
            for _ in range(self.args.num_return_sequences):
                infer_requests.append(deepcopy(row))

        request_config = RequestConfig(
            max_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
        )

        resp_list = []
        if len(infer_requests) > 0:
            # sampler_engine='no' leaves infer_engine as None; in that case every
            # row must already be covered by the cache files.
            assert self.infer_engine is not None, 'Uncached rows remain but no inference engine is configured.'
            resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)

        _cur = 0
        for row in rows:
            row = deepcopy(row)
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    # Fully cached row: no requests were sent for it above.
                    row['choices'] = choices
                    resp_all.append(row)
                    continue

            # Consume this row's block of num_return_sequences responses.
            resps = row
            resps['choices'] = []
            for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
                # Failed generations come back as exceptions; skip them.
                if not isinstance(resp_list[j], Exception):
                    resps['choices'].append(resp_list[j].choices[0].message.content)
            if resps['choices']:
                resp_all.append(resps)
            _cur += 1
        return resp_all

    def do_sample(self, data):
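        """Generate, score, and serialize samples as JSONL training lines.

        Without reward models every sampled choice is kept as a positive; with
        an ORM/PRM, the best ``n_best_to_keep`` surviving choices become
        positives and the lowest-scored candidate is attached as
        ``rejected_response``.
        """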
        generated = []
        resp_all = self.generate(data)
        for resps in resp_all:
            choices = resps['choices']
            messages = resps['messages']
            uuid = get_messages_md5(resps)
            # The original last message holds the ground-truth answer.
            assert messages[-1]['role'] == 'assistant'
            ground_truth = messages[-1]['content']

            # One scoring request per sampled choice, plus one for the ground truth.
            infer_requests = []
            for decoded in choices:
                _resps = deepcopy(resps)
                _resps['messages'][-1]['content'] = decoded
                infer_requests.append(_resps)

            _resps = deepcopy(resps)
            _resps['messages'][-1]['content'] = ground_truth
            infer_requests.append(_resps)
            # Outcome reward (answer correctness); defaults to all-pass when no
            # ORM is configured.
            if self.orm_model is not None:
                orm_score, _orm_mask = get_reward(
                    self.orm_model, infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0)
            else:
                orm_score = np.array([1.0] * len(infer_requests))
                _orm_mask = np.array([True] * len(infer_requests))
            # Process reward (step quality), thresholded by prm_threshold.
            if self.prm_model is not None:
                prm_score, _prm_mask = get_reward(
                    self.prm_model,
                    infer_requests,
                    ground_truths=[ground_truth] * len(infer_requests),
                    threshold=self.args.prm_threshold)
            else:
                prm_score = np.array([1.0] * len(infer_requests))
                _prm_mask = np.array([True] * len(infer_requests))

            # A candidate survives only if it passes both reward thresholds.
            _mask = _orm_mask & _prm_mask
            if not any(_mask):
                continue

            choices.append(ground_truth)
            choices = np.array(choices)

            if self.orm_model is None and self.prm_model is None:
                # No reward models: keep every sampled choice (excluding the
                # appended ground truth) as a positive example.
                positives = choices[:-1]
                for positive in positives:
                    _resps = deepcopy(resps)
                    _resps.pop('choices', None)
                    _resps['id'] = uuid
                    _resps['messages'][-1]['content'] = str(positive)
                    generated.append(json.dumps(_resps, ensure_ascii=False) + '\n')
            else:
                # The ORM signal dominates the PRM signal by a factor of 10, so a
                # correct answer always outranks an incorrect one. Convert both
                # scores to arrays before scaling so a plain-list reward is not
                # accidentally repeated by `* 10`.
                score = np.asarray(prm_score) + np.asarray(orm_score) * 10
                sorted_indices = np.argsort(score)[::-1]
                pos_indexes = sorted_indices[:self.args.n_best_to_keep]
                pos_indexes = [i for i in pos_indexes if _mask[i]]
                neg_index = sorted_indices[-1]
                logger.info(
                    f'orm:{orm_score}, prm:{prm_score}, positive index: {pos_indexes}, negative index: {neg_index}')
                # Skip queries that are too easy: if (almost) all samples already
                # pass the ORM, they carry little training signal. The `- 1`
                # discounts the appended ground truth.
                if self.args.easy_query_threshold is not None and sum(s > 0 for s in orm_score) - 1 >= int(
                        self.args.num_return_sequences * self.args.easy_query_threshold):
                    continue
                if len(pos_indexes) > 0:
                    positives = choices[pos_indexes]
                    negative = choices[neg_index]
                    for positive in positives:
                        _resps = deepcopy(resps)
                        _resps.pop('choices', None)
                        _resps['messages'][-1]['content'] = str(positive)
                        _resps['rejected_response'] = str(negative)
                        _resps['id'] = uuid
                        generated.append(json.dumps(_resps, ensure_ascii=False) + '\n')
        return generated