|
|
import re
|
|
|
import numpy as np
|
|
|
import bisect
|
|
|
from dataclasses import asdict, dataclass
|
|
|
|
|
|
from llm_api import ModelConfig
|
|
|
from prompts.对齐剧情和正文 import prompt as match_plot_and_text
|
|
|
from prompts.审阅.prompt import main as prompt_review
|
|
|
from core.writer_utils import split_text_into_chunks, detect_max_edit_span, run_yield_func
|
|
|
from core.writer_utils import KeyPointMsg
|
|
|
from core.diff_utils import get_chunk_changes
|
|
|
|
|
|
|
|
|
class Chunk(dict):
    """A dict-backed window over a sequence of (x, y) text pairs.

    Three keys are stored:
      * ``'chunk_pairs'``  -- all pairs covered by this chunk, context included
      * ``'source_slice'`` -- (start, stop) position of chunk_pairs in the
                              source pair list
      * ``'text_slice'``   -- sub-range of chunk_pairs that is the editable
                              text; the stop is kept as ``None`` ("to the end")
                              or a negative index so it survives length changes
    """

    @staticmethod
    def _as_bounds(span):
        """Normalize a ``slice`` object to a plain (start, stop) pair."""
        return (span.start, span.stop) if isinstance(span, slice) else span

    def __init__(self, chunk_pairs: tuple[tuple[str, str, str]], source_slice: tuple[int, int], text_slice: tuple[int, int]):
        super().__init__()
        self['chunk_pairs'] = tuple(chunk_pairs)

        source_slice = self._as_bounds(source_slice)
        self['source_slice'] = source_slice

        text_slice = self._as_bounds(text_slice)
        # End-relative encoding: only None or a negative stop is accepted.
        assert text_slice[1] is None or text_slice[1] < 0, 'text_slice end must be None or negative'
        self['text_slice'] = text_slice

    def edit(self, x_chunk=None, y_chunk=None, text_pairs=None):
        """Return a copy whose editable region is replaced.

        Give exactly one of: ``x_chunk`` (replaces the region with a single
        pair keeping the current y text), ``y_chunk`` (symmetric), or
        ``text_pairs`` (an explicit list of replacement pairs).
        """
        if x_chunk is not None:
            text_pairs = [(x_chunk, self.y_chunk)]
        elif y_chunk is not None:
            text_pairs = [(self.x_chunk, y_chunk)]

        updated = list(self['chunk_pairs'])
        # Splice into the editable window; the end-relative text_slice keeps
        # pointing at the right region even if the pair count changes.
        updated[self.text_slice] = list(text_pairs)

        return Chunk(chunk_pairs=tuple(updated), source_slice=self.source_slice, text_slice=self.text_slice)

    @property
    def source_slice(self) -> slice:
        """Position of chunk_pairs within the source pair list."""
        return slice(*self['source_slice'])

    @property
    def chunk_pairs(self) -> tuple[tuple[str, str]]:
        """All (x, y) pairs in this chunk, context included."""
        return self['chunk_pairs']

    @property
    def text_slice(self) -> slice:
        """Editable sub-range of chunk_pairs (stop is None or negative)."""
        return slice(*self['text_slice'])

    @property
    def text_source_slice(self) -> slice:
        """Editable region expressed in source pair-list coordinates."""
        begin = self.source_slice.start + self.text_slice.start
        # A None stop means "runs to the end of the chunk", i.e. offset 0.
        end = self.source_slice.stop + (self.text_slice.stop or 0)
        return slice(begin, end)

    @property
    def text_pairs(self) -> tuple[tuple[str, str]]:
        """Only the editable (x, y) pairs."""
        return self.chunk_pairs[self.text_slice]

    @property
    def x_chunk(self) -> str:
        """Editable x text, concatenated."""
        return ''.join(x for x, _ in self.text_pairs)

    @property
    def y_chunk(self) -> str:
        """Editable y text, concatenated."""
        return ''.join(y for _, y in self.text_pairs)

    @property
    def x_chunk_len(self) -> int:
        """Character length of the editable x text."""
        return len(self.x_chunk)

    @property
    def y_chunk_len(self) -> int:
        """Character length of the editable y text."""
        return len(self.y_chunk)

    @property
    def x_chunk_context(self) -> str:
        """Full x text of the chunk, context included."""
        return ''.join(x for x, _ in self.chunk_pairs)

    @property
    def y_chunk_context(self) -> str:
        """Full y text of the chunk, context included."""
        return ''.join(y for _, y in self.chunk_pairs)

    @property
    def x_chunk_context_len(self) -> int:
        """Character length of the full x text, context included."""
        return len(self.x_chunk_context)

    @property
    def y_chunk_context_len(self) -> int:
        """Character length of the full y text, context included."""
        return len(self.y_chunk_context)
|
|
|
|
|
|
|
|
|
class Writer:
    """Holds two aligned text streams as a list of (x, y) fragment pairs and
    rewrites them chunk-by-chunk through LLM prompt calls.

    x and y are parallel texts (the imported prompts suggest x is plot/outline
    and y is prose -- NOTE(review): confirm against callers).  Methods that
    call a prompt are generators: they yield progress values and deliver their
    final result as the generator return value (StopIteration.value).
    """

    def __init__(self, xy_pairs, global_context=None, model:ModelConfig=None, sub_model:ModelConfig=None, x_chunk_length=1000, y_chunk_length=1000, max_thread_num=5):
        # List of (x_fragment, y_fragment) tuples; mutated in place by
        # apply_chunks.
        self.xy_pairs = xy_pairs
        # Extra keyword arguments merged into every write prompt (write_text).
        self.global_context = global_context or {}

        self.model = model          # main model for writing/reviewing
        self.sub_model = sub_model  # cheaper model for auxiliary alignment

        # Target chunk sizes in characters for each stream.
        self.x_chunk_length = x_chunk_length
        self.y_chunk_length = y_chunk_length

        # Max number of generators advanced per round in batch_yield.
        self.max_thread_num = max_thread_num

    @property
    def x(self):
        """Full x text (all x fragments concatenated)."""
        return ''.join(pair[0] for pair in self.xy_pairs)

    @property
    def y(self):
        """Full y text (all y fragments concatenated)."""
        return ''.join(pair[1] for pair in self.xy_pairs)

    @property
    def x_len(self):
        """Character length of the full x text."""
        return sum(len(pair[0]) for pair in self.xy_pairs)

    @property
    def y_len(self):
        """Character length of the full y text."""
        return sum(len(pair[1]) for pair in self.xy_pairs)

    def get_model(self):
        """Model used for the main writing/review prompts."""
        return self.model

    def get_sub_model(self):
        """Model used for auxiliary prompts (plot/text alignment)."""
        return self.sub_model

    def count_span_length(self, span):
        """Return (x_chars, y_chars) of the pairs in span[0]:span[1]."""
        pairs = self.xy_pairs[span[0]:span[1]]
        return sum(len(pair[0]) for pair in pairs), sum(len(pair[1]) for pair in pairs)

    def align_span(self, x_span=None, y_span=None):
        """Snap a character span in x (or y) outward to pair boundaries.

        Exactly one of x_span/y_span must be given as (start, stop) character
        offsets.  Returns (aligned_span, pair_span): the smallest
        boundary-aligned character span containing the input, and the
        corresponding (start, stop) range of pair indices.
        """
        if x_span is None and y_span is None:
            raise ValueError("Either x_span or y_span must be provided")

        if x_span is not None and y_span is not None:
            raise ValueError("Only one of x_span or y_span should be provided")

        is_x = x_span is not None
        z_span = x_span if is_x else y_span
        # cumsum_z[i] == total characters of the chosen stream before pair i.
        cumsum_z = np.cumsum([0] + [len(pair[0 if is_x else 1]) for pair in self.xy_pairs]).tolist()

        l, r = z_span
        # Widen only: the start moves left to a boundary, the end moves right.
        start_chunk = bisect.bisect_right(cumsum_z, l) - 1
        end_chunk = bisect.bisect_left(cumsum_z, r)

        aligned_l = cumsum_z[start_chunk]
        aligned_r = cumsum_z[end_chunk]

        aligned_span = (aligned_l, aligned_r)
        pair_span = (start_chunk, end_chunk)

        # Sanity checks; these fail for empty or out-of-range input spans.
        assert aligned_l <= l < aligned_r, "aligned_span does not properly contain the start of the input span"
        assert aligned_l < r <= aligned_r, "aligned_span does not properly contain the end of the input span"
        assert 0 <= start_chunk < end_chunk <= len(self.xy_pairs), "pair_span is out of bounds"
        assert sum(len(pair[0 if is_x else 1]) for pair in self.xy_pairs[start_chunk:end_chunk]) == aligned_r - aligned_l, "aligned_span and pair_span do not match"

        return aligned_span, pair_span

    def get_chunk(self, pair_span=None, x_span=None, y_span=None, context_length=0, smooth=True):
        """Build a Chunk around a span, padded with surrounding context.

        Exactly one of pair_span (pair indices), x_span or y_span (character
        offsets) must be given.  Note context_length is measured in *pairs*
        for pair_span but in *characters* for x_span/y_span.  With smooth=True
        a character span is first snapped to pair boundaries via align_span.
        """
        if sum(x is not None for x in [pair_span, x_span, y_span]) != 1:
            raise ValueError("Exactly one of pair_span, x_span, or y_span must be provided")

        assert pair_span is None or (pair_span[0] >= 0 and pair_span[1] <= len(self.xy_pairs)), "pair_span is out of bounds"

        is_x = x_span is not None
        is_pair = pair_span is not None

        if is_pair:
            # Context measured in pair indices, clamped to the pair list.
            context_pair_span = (
                max(0, pair_span[0] - context_length),
                min(len(self.xy_pairs), pair_span[1] + context_length)
            )
        else:
            assert smooth, "smooth must be True"
            span = x_span if is_x else y_span
            if smooth:
                span, pair_span = self.align_span(x_span=span if is_x else None, y_span=span if not is_x else None)

            # Context measured in characters, clamped to the stream length.
            context_span = (
                max(0, span[0] - context_length),
                min(self.x_len if is_x else self.y_len, span[1] + context_length)
            )

            context_span, context_pair_span = self.align_span(x_span=context_span if is_x else None, y_span=context_span if not is_x else None)

        chunk_pairs = self.xy_pairs[context_pair_span[0]:context_pair_span[1]]
        source_slice = context_pair_span
        # text_slice locates the core span inside chunk_pairs; the stop is
        # stored relative to the chunk end (<= 0) so Chunk can keep it as a
        # negative index, with 0 mapped to None meaning "to the end".
        text_slice = (pair_span[0] - context_pair_span[0], pair_span[1] - context_pair_span[1])
        assert text_slice[1] <= 0, "text_slice end must be negative"
        text_slice = (text_slice[0], None if text_slice[1] == 0 else text_slice[1])

        return Chunk(
            chunk_pairs=chunk_pairs,
            source_slice=source_slice,
            text_slice=text_slice
        )

    def get_chunk_pair_span(self, chunk: Chunk):
        """Locate the pair-index span in self.xy_pairs covered by chunk's text.

        Fast path: trust chunk.text_source_slice if those pairs still
        concatenate to the chunk's texts.  Otherwise scan for pairs whose
        first/last 50 characters match the chunk's ends (heuristic; the final
        assert catches any mis-location).
        """
        pair_start, pair_end = chunk.text_source_slice.start, chunk.text_source_slice.stop
        merged_x_chunk = ''.join(p[0] for p in self.xy_pairs[pair_start:pair_end])
        merged_y_chunk = ''.join(p[1] for p in self.xy_pairs[pair_start:pair_end])
        if merged_x_chunk == chunk.x_chunk and merged_y_chunk == chunk.y_chunk:
            return pair_start, pair_end

        # Fallback: the pair list shifted since the chunk was cut; search for
        # the chunk's boundaries by prefix/suffix match.
        pair_start, pair_end = 0, len(self.xy_pairs)
        x_chunk, y_chunk = chunk.x_chunk, chunk.y_chunk
        for i, (x, y) in enumerate(self.xy_pairs):
            if x_chunk[:50].startswith(x[:50]) and y_chunk[:50].startswith(y[:50]):
                pair_start = i
                break

        for i in range(pair_start, len(self.xy_pairs)):
            x, y = self.xy_pairs[i]
            if x_chunk[-50:].endswith(x[-50:]) and y_chunk[-50:].endswith(y[-50:]):
                pair_end = i + 1
                break

        # Verify the recovered span reproduces the chunk text exactly.
        merged_x_chunk = ''.join(p[0] for p in self.xy_pairs[pair_start:pair_end])
        merged_y_chunk = ''.join(p[1] for p in self.xy_pairs[pair_start:pair_end])
        assert x_chunk == merged_x_chunk and y_chunk == merged_y_chunk, "Chunk mismatch"

        return (pair_start, pair_end)

    def apply_chunks(self, chunks: list[Chunk], new_chunks: list[Chunk]):
        """Replace each chunk's pair span with its new chunk's text pairs.

        Chunks must not overlap.  Replacements are applied right-to-left so
        earlier spans' indices stay valid while the list length changes.
        """
        occupied_pair_span = [False] * len(self.xy_pairs)
        pair_span_list = [self.get_chunk_pair_span(e) for e in chunks]
        for pair_span in pair_span_list:
            assert not any(occupied_pair_span[pair_span[0]:pair_span[1]]), "Chunk overlap"
            occupied_pair_span[pair_span[0]:pair_span[1]] = [True] * (pair_span[1] - pair_span[0])

        new_pairs_list = [e.text_pairs for e in new_chunks]

        # Sort by span start descending -> splice from the tail of the list.
        sorted_spans_with_new_pairs = sorted(
            zip(pair_span_list, new_pairs_list),
            key=lambda x: x[0][0],
            reverse=True
        )

        for (start, end), new_pairs in sorted_spans_with_new_pairs:
            self.xy_pairs[start:end] = new_pairs

    def get_chunks(self, pair_span=None, chunk_length_ratio=1, context_length_ratio=1, offset_ratio=0):
        """Partition pair_span into consecutive Chunks for batch processing.

        Each step consumes about 80% of the (scaled) chunk length in
        characters; of the x-aligned and y-aligned candidate chunks the one
        spanning fewer pairs is kept.  A fractional offset_ratio shortens only
        the first chunk, staggering boundaries relative to a previous pass
        (NOTE(review): presumed intent -- confirm with callers).
        """
        pair_span = pair_span or (0, len(self.xy_pairs))
        chunk_length = self.x_chunk_length * chunk_length_ratio, self.y_chunk_length * chunk_length_ratio
        context_length = self.x_chunk_length//2 * context_length_ratio, self.y_chunk_length//2 * context_length_ratio

        # Rebind the fraction to absolute (x_chars, y_chars) for the first
        # chunk.  NOTE(review): a value >= 1 would hit offset_ratio[0] below
        # while still an int and raise -- only 0 or a fraction is supported.
        if 0 < offset_ratio < 1:
            offset_ratio = int(chunk_length[0] * offset_ratio), int(chunk_length[1] * offset_ratio)

        chunks = []
        start = pair_span[0]
        # cstart/max_cend are (x_chars, y_chars) cumulative counts.
        cstart = self.count_span_length((0, start))
        max_cend = self.count_span_length((0, pair_span[1]))
        while start < pair_span[1]:
            if offset_ratio != 0:
                # First iteration only: shortened chunk, then reset.
                cend = cstart[0] + offset_ratio[0], cstart[1] + offset_ratio[1]
                offset_ratio = 0
            else:
                cend = cstart[0] + int(chunk_length[0] * 0.8), cstart[1] + int(chunk_length[1] * 0.8)
            cend = min(cend[0], max_cend[0]), min(cend[1], max_cend[1])

            x_len, y_len = cend[0] - cstart[0], cend[1] - cstart[1]
            if x_len > 0:
                chunk1 = self.get_chunk(x_span=(cstart[0], cend[0]), context_length=context_length[0])
            if y_len > 0:
                chunk2 = self.get_chunk(y_span=(cstart[1], cend[1]), context_length=context_length[1])

            if x_len > 0 and y_len == 0:
                chunk = chunk1
            elif x_len == 0 and y_len > 0:
                chunk = chunk2
            elif x_len > 0 and y_len > 0:
                # Prefer the candidate covering fewer source pairs.
                chunk = chunk1 if chunk1.source_slice.stop - chunk1.source_slice.start < chunk2.source_slice.stop - chunk2.source_slice.start else chunk2
            else:
                raise ValueError("Both x_span and y_span have zero length")

            chunks.append(chunk)
            # Resume right after this chunk's text (context excluded).
            start = chunk.text_source_slice.stop
            cstart = self.count_span_length((0, start))

        return chunks

    def batch_yield(self, generators, chunks, prompt_name=None):
        """Drive several prompt generators in a cooperative round-robin.

        At most max_thread_num unfinished generators are advanced per round
        (single-threaded despite the name).  Each round yields the list of the
        latest (yield_value, chunk) per generator; when prompt_name is given
        the whole batch is bracketed by a KeyPointMsg and its set_finished().
        Returns the list of generator return values, in input order.
        """
        results = [None] * len(generators)
        yields = [None] * len(generators)
        finished = [False] * len(generators)
        first_iter_flag = True
        while True:
            co_num = 0  # generators touched this round
            for i, gen in enumerate(generators):
                if finished[i]:
                    continue

                try:
                    co_num += 1
                    yield_value = next(gen)
                    yields[i] = (yield_value, chunks[i])
                except StopIteration as e:
                    # Generator finished; its return value is the result.
                    results[i] = e.value
                    finished[i] = True
                    if yields[i] is None: yields[i] = (None, chunks[i])

                if co_num >= self.max_thread_num:
                    break

            if all(finished):
                break

            if first_iter_flag and prompt_name is not None:
                yield (kp_msg := KeyPointMsg(prompt_name=prompt_name))
                first_iter_flag = False

            yield yields

        # Close the KeyPointMsg bracket only if one was opened above.
        if not first_iter_flag and prompt_name is not None:
            yield kp_msg.set_finished()

        return results

    def diff_to(self, cur, pair_span=None):
        """Diff this writer's y text against another writer ``cur``'s.

        Returns (x, y_before, y_after) triples restricted to pair_span.
        Assumes cur was derived from self so both share the same x text
        (NOTE(review): inferred from the x-length matching below -- confirm).
        """
        if pair_span is None:
            pair_span = (0, len(self.xy_pairs))

        if self.count_span_length(pair_span)[0] == 0:
            # No x text in the span: pair up y fragments positionally,
            # padding the shorter side with empty strings.
            pair_span2 = (0 + pair_span[0], len(cur.xy_pairs) - (len(self.xy_pairs) - pair_span[1]))
            y_list = [e[1] for e in self.xy_pairs[pair_span[0]:pair_span[1]]]
            y2_list =[e[1] for e in cur.xy_pairs[pair_span2[0]:pair_span2[1]]]

            y_list += ['',] * max(len(y2_list) - len(y_list), 0)
            y2_list += ['',] * max(len(y_list) - len(y2_list), 0)

            data_chunks = [('', y, y2) for y, y2 in zip(y_list, y2_list)]

            return data_chunks

        # Two-pointer walk: grow each side's pair window until both cover the
        # same amount of x text, then record that aligned region.
        pre_pointer = 0, 1
        cur_pointer = 0, 1

        cum_sum_pre = np.cumsum([0] + [len(pair[0]) for pair in self.xy_pairs])
        cum_sum_cur = np.cumsum([0] + [len(pair[0]) for pair in cur.xy_pairs])

        apply_chunks = []

        while pre_pointer[1] <= len(self.xy_pairs) and cur_pointer[1] <= len(cur.xy_pairs):
            if cum_sum_pre[pre_pointer[1]] - cum_sum_pre[pre_pointer[0]] == cum_sum_cur[cur_pointer[1]] - cum_sum_cur[cur_pointer[0]]:
                # Windows cover equal x text: emit and restart after them.
                chunk = self.get_chunk(pair_span=pre_pointer)
                value = "".join(pair[1] for pair in cur.xy_pairs[cur_pointer[0]:cur_pointer[1]])
                apply_chunks.append((chunk, 'y_chunk', value))

                pre_pointer = pre_pointer[1], pre_pointer[1] + 1
                cur_pointer = cur_pointer[1], cur_pointer[1] + 1
            elif cum_sum_pre[pre_pointer[1]] - cum_sum_pre[pre_pointer[0]] < cum_sum_cur[cur_pointer[1]] - cum_sum_cur[cur_pointer[0]]:
                pre_pointer = pre_pointer[0], pre_pointer[1] + 1
            else:
                cur_pointer = cur_pointer[0], cur_pointer[1] + 1

        # Both sides must end fully consumed, i.e. total x lengths matched.
        assert pre_pointer[1] == len(self.xy_pairs) + 1 and cur_pointer[1] == len(cur.xy_pairs) + 1

        # Keep only regions lying entirely inside the requested span.
        filtered_apply_chunks = []
        for e in apply_chunks:
            text_source_slice = e[0].text_source_slice
            if text_source_slice.start >= pair_span[0] and text_source_slice.stop <= pair_span[1]:
                filtered_apply_chunks.append(e)

        data_chunks = []
        for chunk, key, value in filtered_apply_chunks:
            data_chunks.append((chunk.x_chunk, chunk.y_chunk, value))

        return data_chunks

    def apply_chunk(self, chunk:Chunk, key, value):
        """Set one attribute (e.g. 'y_chunk') of chunk and write it back."""
        if not isinstance(chunk, Chunk):
            # Allow a plain dict (e.g. deserialized) in place of a Chunk.
            chunk = Chunk(**chunk)
        new_chunk = chunk.edit(**{key: value})
        self.apply_chunks([chunk], [new_chunk])

    def write_text(self, chunk:Chunk, prompt_main, user_prompt_text, input_keys=None, model=None):
        """Generate text for chunk via prompt_main; return the edited chunk.

        Chunk attributes are passed to the prompt under renamed keys (x, y,
        context_x, context_y) and merged with global_context.  The prompt
        result may name the attribute to update via 'text_key'; the default
        target is 'y_chunk'.  Generator: yields prompt progress.
        """
        chunk2prompt_key = {
            'x_chunk': 'x',
            'y_chunk': 'y',
            'x_chunk_context': 'context_x',
            'y_chunk_context': 'context_y'
        }

        if input_keys is not None:
            prompt_kwargs = {k: getattr(chunk, k) for k in input_keys}
            assert all(prompt_kwargs.values()), "Missing required context keys"
        else:
            prompt_kwargs = {k: getattr(chunk, k) for k in chunk2prompt_key.keys()}

        # Rename chunk attribute names to the prompt's parameter names.
        prompt_kwargs = {chunk2prompt_key.get(k, k): v for k, v in prompt_kwargs.items()}

        prompt_kwargs.update(self.global_context)

        result = yield from prompt_main(
            model=model or self.get_model(),
            user_prompt=user_prompt_text,
            **prompt_kwargs
        )

        update_dict = {}
        if 'text_key' in result:
            update_dict[result['text_key']] = result['text']
        else:
            update_dict['y_chunk'] = result['text']

        return chunk.edit(**update_dict)

    def review_text(self, chunk:Chunk, prompt_name, model=None):
        """Run the review prompt on chunk's y text; return the review text."""
        result = yield from prompt_review(
            model=model or self.get_model(),
            prompt_name=prompt_name,
            y=chunk.y_chunk
        )

        return result['text']

    def map_text_wo_llm(self, chunk:Chunk):
        """Re-split over-long one-sided pairs to the length limits, no LLM.

        Pairs with only x (or only y) content are split into smaller pairs;
        pairs with both sides raise if either side exceeds its limit, since
        splitting them would require alignment.
        """
        new_xy_pairs = []
        for x, y in chunk.text_pairs:
            if x.strip() and not y.strip():
                x_pairs = split_text_into_chunks(x, self.x_chunk_length, min_chunk_n=1, min_chunk_size=5)
                new_xy_pairs.extend([(x_pair, y) for x_pair in x_pairs])
            elif not x.strip() and y.strip():
                y_pairs = split_text_into_chunks(y, self.y_chunk_length, min_chunk_n=1, min_chunk_size=5)
                new_xy_pairs.extend([(x, y_pair) for y_pair in y_pairs])
            else:
                if len(x) > self.x_chunk_length or len(y) > self.y_chunk_length:
                    raise ValueError("窗口太小或段落太长!考虑选择更大的窗口长度或手动分段。")
                new_xy_pairs.append((x, y))

        return chunk.edit(text_pairs=new_xy_pairs)

    def map_text(self, chunk:Chunk):
        """Re-align chunk's x and y into matching pairs via the sub model.

        Generator; returns (new_chunk, True, '') -- NOTE(review): the flag is
        always True here, the tuple shape is presumably kept for callers.
        """
        if chunk.x_chunk.strip():
            x_pairs = split_text_into_chunks(chunk.x_chunk, self.x_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=20)
            assert len(x_pairs) >= len(chunk.text_pairs), "未知错误!合并所有区块后再分区块,结果更少?"
            if len(x_pairs) == len(chunk.text_pairs):
                # Split count unchanged -> existing pairing already fits.
                return chunk, True, ''
        else:
            # No x text at all: split y alone and pair pieces with empty x.
            y_pairs = split_text_into_chunks(chunk.y_chunk, self.y_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=20)
            new_xy_pairs = [('', y) for y in y_pairs]
            return chunk.edit(text_pairs=new_xy_pairs), True, ''

        try:
            # Prefer at least as many y pieces as x pieces.
            y_pairs = split_text_into_chunks(chunk.y_chunk, self.y_chunk_length, min_chunk_n=len(x_pairs), min_chunk_size=5, max_chunk_n=20)
        except Exception as e:
            # Fallback: split y freely, then cap x below the y piece count.
            y_pairs = split_text_into_chunks(chunk.y_chunk, self.y_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=20)
            x_pairs = split_text_into_chunks(chunk.x_chunk, self.x_chunk_length, min_chunk_n=1, min_chunk_size=5, max_chunk_n=int(0.8 * len(y_pairs)))

        try:
            # Ask the sub model which text chunks each plot chunk maps to.
            gen = match_plot_and_text.main(
                model=self.get_sub_model(),
                plot_chunks=x_pairs,
                text_chunks=y_pairs
            )
            while True:
                yield next(gen)
        except StopIteration as e:
            output = e.value

        # plot2text: list of (x_index_list, y_index_list) groupings.
        x2y = output['plot2text']
        new_xy_pairs = []
        for xi_list, yi_list in x2y:
            xl, xr = xi_list[0], xi_list[-1]
            new_xy_pairs.append(("".join(x_pairs[xl:xr+1]), "".join(y_pairs[i] for i in yi_list)))

        new_chunk = chunk.edit(text_pairs=new_xy_pairs)
        return new_chunk, True, ''

    def batch_map_text(self, chunks):
        """Run map_text over chunks via batch_yield; return its results."""
        results = yield from self.batch_yield(
            [self.map_text(e) for e in chunks], chunks, prompt_name='映射文本')
        return results

    def batch_write_apply_text(self, chunks, prompt_main, user_prompt_text):
        """Write new text for chunks, re-align it, and splice it back in."""
        new_chunks = yield from self.batch_yield(
            [self.write_text(e, prompt_main, user_prompt_text) for e in chunks],
            chunks, prompt_name='创作文本')

        results = yield from self.batch_map_text(new_chunks)
        # map_text returns (chunk, flag, msg); keep only the chunk.
        new_chunks2 = [e[0] for e in results]

        self.apply_chunks(chunks, new_chunks2)

    def batch_review_write_apply_text(self, chunks, write_prompt_main, review_prompt_name):
        """Review each chunk, rewrite it per its review, and splice back."""
        reviews = yield from self.batch_yield(
            [self.review_text(e, review_prompt_name) for e in chunks],
            chunks, prompt_name='审阅文本')

        rewrite_instrustion = "\n\n根据审阅意见,重新创作,如果审阅意见表示无需改动,则保持原样输出。"

        new_chunks = yield from self.batch_yield(
            [self.write_text(chunk, write_prompt_main, review + rewrite_instrustion) for chunk, review in zip(chunks, reviews)],
            chunks, prompt_name='创作文本')

        results = yield from self.batch_map_text(new_chunks)
        # map_text returns (chunk, flag, msg); keep only the chunk.
        new_chunks2 = [e[0] for e in results]

        self.apply_chunks(chunks, new_chunks2)