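# Source: nifty-lab / data/system_prompt_render.py (7.27 kB)
# Last commit e18ede0 by IgorCSIS: "Initial Space deploy, ZeroGPU adapter for Lance"
# (File-viewer chrome captured with the file -- "Raw", "History Blame Contribute Delete" --
# is not part of the module.)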
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any, Dict, List, Tuple
from jinja2 import Environment, BaseLoader
# Jinja2 source for the chat prompt. It expects three template variables
# (`system_prompt`, `msgs`, `include_assistant_content`) and a filter named
# `render_mm_list` that turns a message's content into text.
JINJA_PROMPT_TMPL = "".join(
    [
        # System turn: always emitted first and always closed.
        "<|im_start|>system\n",
        "{{ system_prompt }}<|im_end|>\n",
        # One turn per message, opened with its role.
        "{% for m in msgs -%}",
        "<|im_start|>{{ m.role }}\n",
        # Assistant content is rendered only when include_assistant_content
        # is true; every other role's content is always rendered.
        "{% if not (m.role == 'assistant' and not include_assistant_content) -%}",
        "{{ m.content | render_mm_list }}",
        "{% endif -%}",
        # A trailing assistant turn is left open (no <|im_end|>) unless
        # include_assistant_content is true.
        "{% if (not (loop.last and m.role == 'assistant')) or include_assistant_content -%}",
        "<|im_end|>\n",
        "{% endif -%}",
        "{% endfor -%}",
    ]
)
# Special-token strings delimiting a vision block and its placeholder pads.
VS, VE = "<|vision_start|>", "<|vision_end|>"
VP, IP = "<|video_pad|>", "<|image_pad|>"


def _find_token_run(haystack: List[int], needle: List[int], full_match: bool) -> int:
    """Return the start index of the first occurrence of `needle`, or -1.

    With ``full_match=True`` every token of `needle` must match. With
    ``full_match=False`` only the first and last tokens are compared (the
    heuristic the target-text lookup has always used). A candidate that would
    run past the end of `haystack` is never accepted.
    """
    if not needle:
        return -1
    width = len(needle)
    first, last = needle[0], needle[-1]
    limit = len(haystack) - width  # Last start offset where needle still fits.
    pos = 0
    while pos <= limit:
        try:
            pos = haystack.index(first, pos)
        except ValueError:
            return -1
        if pos > limit:
            return -1
        if full_match:
            matched = haystack[pos:pos + width] == needle
        else:
            matched = haystack[pos + width - 1] == last
        if matched:
            return pos
        pos += 1
    return -1


def expand_and_index_by_token_ids_new(
    rendered_text: str,
    tokens: List[int],
    tokenizer,
    target_text: str = "",
    search_text: str = "",
) -> Tuple[List[int], List[List[int]], List[int], List[int]]:
    """Expand each vision placeholder to K pad tokens and index regions of interest.

    `rendered_text` is tokenized, then every ``VS <pad> VE`` triple is replaced
    by ``VS + <pad> * K + VE``, where K is the next entry of `tokens` (consumed
    in occurrence order). The pad id is copied from the tokenized text, so both
    video and image pads are handled. The scan assumes VS, the pad and VE each
    tokenize to a single id (it looks for ``vs_ids[0]`` and skips 3 positions).

    Args:
        rendered_text: Prompt text containing zero or more vision blocks.
        tokens: Pad count K for each vision block, in occurrence order.
        tokenizer: Callable used as ``tokenizer(text, add_special_tokens=False)``
            returning a mapping with an ``"input_ids"`` list.
        target_text: Optional text to locate in the expanded ids; only its
            first and last token ids are compared.
        search_text: Optional text to locate with an exact token-by-token match.

    Returns:
        all_ids      : token ids of the expanded sequence
        spans_index  : indexes of each pad block in all_ids, in occurrence
                       order, e.g. [[100..199], [350..549], ...]
        tgt_index    : indexes from just after the first `target_text` match
                       to the end of all_ids, or [] if not found
        search_index : indexes covered by the first `search_text` match,
                       or [] if not found

    Raises:
        IndexError: if `tokens` has fewer entries than there are vision blocks.
    """
    vs_ids = tokenizer(VS, add_special_tokens=False)["input_ids"]
    ve_ids = tokenizer(VE, add_special_tokens=False)["input_ids"]
    base_ids = tokenizer(rendered_text, add_special_tokens=False)["input_ids"]

    # ---------- 1) Expand every vision block and record where its pads land ----------
    all_ids: List[int] = []
    spans_index: List[List[int]] = []
    # None never occurs in base_ids, so an empty vs_ids simply means "no blocks".
    vs_token = vs_ids[0] if vs_ids else None
    scan = 0    # Scan pointer into base_ids.
    tk_ptr = 0  # Pointer into tokens (K per block).
    while True:
        try:
            vs_pos = base_ids.index(vs_token, scan)
        except ValueError:
            # No further vision block: copy the tail verbatim and stop.
            all_ids.extend(base_ids[scan:])
            break
        all_ids.extend(base_ids[scan:vs_pos])
        # The placeholder id is whatever follows VS in the rendered sequence.
        pad_ids = base_ids[vs_pos + 1:vs_pos + 2]
        pad_count = int(tokens[tk_ptr])
        tk_ptr += 1
        start = len(all_ids) + 1  # First pad sits right after the VS token.
        all_ids.extend(vs_ids + pad_ids * pad_count + ve_ids)
        spans_index.append(list(range(start, start + pad_count)))
        scan = vs_pos + 3  # Skip the original VS, pad, VE triple.

    # ---------- 2) Everything after the first target_text occurrence ----------
    tgt_index: List[int] = []
    if target_text:
        tgt_ids = tokenizer(target_text, add_special_tokens=False)["input_ids"]
        tgt_pos = _find_token_run(all_ids, tgt_ids, full_match=False)
        if tgt_pos >= 0:
            tgt_index = list(range(tgt_pos + len(tgt_ids), len(all_ids)))

    # ---------- 3) Exact location of the first search_text occurrence ----------
    search_index: List[int] = []
    if search_text:
        search_ids = tokenizer(search_text, add_special_tokens=False)["input_ids"]
        search_pos = _find_token_run(all_ids, search_ids, full_match=True)
        if search_pos >= 0:
            search_index = list(range(search_pos, search_pos + len(search_ids)))

    return all_ids, spans_index, tgt_index, search_index
def _extract_system_prompt(messages: List[Dict[str, Any]], default_system: str) -> str:
for m in messages:
if m.get("role") == "system":
c = m.get("content", "")
if isinstance(c, str):
return c
if isinstance(c, list):
texts = [it.get("text", "") for it in c if isinstance(it, dict) and it.get("type") == "text"]
if texts:
return "".join(texts)
return default_system
def _normalize_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
norm: List[Dict[str, Any]] = []
for m in messages:
role = m.get("role")
if role == "system":
continue
c = m.get("content", "")
if isinstance(c, str):
items = [{"type": "text", "text": c}]
elif isinstance(c, list):
items = c
else:
items = []
norm.append({"role": role, "content": items})
return norm
def render_qwenvl_prompt(
    messages: List[Dict[str, Any]],
    default_system: str = "You are a helpful assistant.",
    include_assistant_content: bool = False,  # Key option: whether to render assistant text.
    force_video_pad: bool = False,
) -> str:
    """Render chat `messages` into a single prompt string via JINJA_PROMPT_TMPL.

    Args:
        messages: Chat messages; each is a dict with ``role`` and ``content``
            (a string or a list of typed items).
        default_system: System prompt used when `messages` carries no usable
            system message.
        include_assistant_content: Passed to the template; controls whether
            assistant content and the closing tag of a trailing assistant
            turn are emitted.
        force_video_pad: Selects the placeholder used for ``image`` items
            (see the note in the body).

    Returns:
        The rendered prompt text.
    """
    video_block = "<|vision_start|><|video_pad|><|vision_end|>"
    image_block = "<|vision_start|><|image_pad|><|vision_end|>"
    # NOTE(review): as written, image items get the *image* pad only when
    # force_video_pad is True and the *video* pad otherwise, which reads as
    # inverted relative to the flag's name. Behaviour is kept as-is here;
    # confirm the intent against callers before changing it.
    image_placeholder = image_block if force_video_pad else video_block

    def _render_mm_list(items: Any) -> str:
        # Jinja filter: flatten a message's content into prompt text.
        if isinstance(items, str):
            return items
        if not isinstance(items, list):
            return ""
        rendered: List[str] = []
        for item in items:
            if not isinstance(item, dict):
                continue
            kind = item.get("type")
            if kind == "text":
                rendered.append(item.get("text", ""))
            elif kind == "image":
                rendered.append(image_placeholder)
            elif kind == "video":
                rendered.append(video_block)
            # Other modalities can be added here; unknown types render nothing.
        return "".join(rendered)

    jinja_env = Environment(
        loader=BaseLoader(),
        autoescape=False,
        trim_blocks=True,  # Remove newlines after block endings.
        lstrip_blocks=True,  # Remove whitespace before block starts.
        newline_sequence="\n",
        keep_trailing_newline=False,
    )
    jinja_env.filters["render_mm_list"] = _render_mm_list

    context = {
        "system_prompt": _extract_system_prompt(messages, default_system),
        "msgs": _normalize_messages(messages),
        "include_assistant_content": include_assistant_content,
    }
    return jinja_env.from_string(JINJA_PROMPT_TMPL).render(**context)