import sys

import torch
import torch.utils._triton
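# Fake CUDA / Triton introspection so that Unsloth can be imported on a
# machine without a GPU: report a single CUDA device with compute
# capability 8.0, a dummy raw stream, and no Triton installation.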
def fake_is_available():
    return True


def fake_device_capability(*args, **kwargs):
    return (8, 0)


def fake_current_device():
    return 0


def fake_device_count():
    return 1


def has_triton():
    return False


def get_fake_stream(*args, **kwargs):
    return 0
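# Install the stubs on the already-imported torch modules.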
sys.modules["torch"].cuda.is_available = fake_is_available
sys.modules["torch"].cuda.get_device_capability = fake_device_capability
sys.modules["torch"].cuda.current_device = fake_current_device
sys.modules["torch"].cuda.device_count = fake_device_count
sys.modules["torch.utils._triton"].has_triton = has_triton
sys.modules["torch._C"]._cuda_getCurrentRawStream = get_fake_stream
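# Streamlit, Unsloth, and Transformers are imported only after the patches above are in place.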
import streamlit as st
from unsloth.chat_templates import get_chat_template, CHAT_TEMPLATES
from unsloth_zoo.dataset_utils import train_on_responses_only
from transformers import AutoProcessor


class DummyArgs:
    pass
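# Minimal single-example stand-in for a datasets.Dataset, supporting only the
# map / __len__ / __getitem__ calls needed below.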
class DummyDataset:
    def __init__(self, example):
        self.example = [example]

    def map(self, function, *args, **kwargs):
        self.example[0].update(function(self.example[0]))
        return self

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return self.example[idx]


class DummyTrainer:
    pass
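# Streamlit page header and "made with Unsloth" badge.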
st.title('Train With Response Only Analyzer')

_, col, _ = st.columns([1, 1, 1])
col.image("https://raw.githubusercontent.com/unslothai/unsloth/main/images/made%20with%20unsloth.png", width=200)
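# Load the processor for the chosen model; multimodal processors wrap the text
# tokenizer, so unwrap it when present.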
model = st.text_input("Enter HuggingFace model name", st.query_params.get("model", "Qwen/Qwen2-VL-7B-Instruct"))
processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)
text_tokenizer = processor if not hasattr(processor, "tokenizer") else processor.tokenizer
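# Resolve which chat template to use: a named Unsloth template (preselected via
# the chat_template_idx query parameter or chosen in the selectbox), or the
# model's own default.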
chat_template_predefined = st.query_params.get("chat_template_idx", None)
possible_templates = ["model_default"] + sorted(CHAT_TEMPLATES.keys())
if chat_template_predefined is not None:
    chat_template_idx = possible_templates.index(chat_template_predefined)
else:
    chat_template_idx = 0

chat_template_key = st.selectbox("Select chat template", possible_templates, index=chat_template_idx)

if chat_template_key == "model_default":
    chat_template = None
else:
    chat_template = CHAT_TEMPLATES.get(chat_template_key)[0]
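# "model_default" falls back to the chat template bundled with the tokenizer.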
if chat_template is None:
    chat_template = text_tokenizer.chat_template
    if chat_template is None:
        st.warning("Chat template not found in the tokenizer. Not using any chat template.")
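# Let the user inspect the resolved Jinja2 chat template.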
with st.expander("Click to see the chat template"):
    st.markdown("#### Chat Template (in Jinja2 format)")
    st.code(chat_template, language="jinja2")

sample = {"conversations": [{'content': 'Do you like Unsloth?', 'role': 'user'}, {'content': 'Yes', 'role': 'assistant'}, {'content': 'Will you star them on GitHub?', 'role': 'user'}, {'content': 'Sure!', 'role': 'assistant'}]}

message_sample = sample.get("conversations", "")
message = st.text_area("Enter your message here", st.query_params.get("message", str(message_sample)))
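# The text area holds a Python literal (a list of role/content dicts);
# parse it, keeping the raw string if parsing fails.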
import ast

try:
    message = ast.literal_eval(message)
except Exception:
    pass

if chat_template is not None:
    converted_message = text_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False, chat_template=chat_template)
else:
    converted_message = message
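# Show the templated prompt text.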
st.markdown("#### Original Message")
st.code(converted_message, language="html")

instruction_part = st.text_input("Enter instruction Part", st.query_params.get("instruction_part", "<|im_start|>user"))
response_part = st.text_input("Enter response Part", st.query_params.get("response_part", "<|im_start|>assistant"))
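# Build a throwaway trainer around the single tokenized prompt so that
# train_on_responses_only can be applied to it.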
trainer = DummyTrainer()
trainer.train_dataset = DummyDataset({"input_ids": [text_tokenizer.encode(converted_message)]})
trainer.tokenizer = text_tokenizer
trainer.args = DummyArgs()
trainer.args.dataset_kwargs = {"skip_prepare_dataset": False}
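# Mask the prompt, then make the mask visible: every label of -100 (ignored by
# the loss) is replaced with the token id of "x" before decoding.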
trainer = train_on_responses_only(trainer, instruction_part, response_part)
ids = trainer.train_dataset[0]["labels"][0]
mask = text_tokenizer.encode("x", add_special_tokens = False)[0]
masked_text = text_tokenizer.decode([mask if x == -100 else x for x in ids])

st.markdown("#### Masked Prompt ('x' is the mask token)")
st.code(masked_text, language="html")

st.markdown("#### Your Unsloth code snippet")
code = f"""from unsloth.chat_templates import train_on_responses_only

trainer = train_on_responses_only(
    trainer,
    instruction_part = "{instruction_part}",
    response_part = "{response_part}",
)
"""
st.code(code, language="python")

st.markdown("#### You may share the following URL with others to show them the results")
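# Rebuild a shareable URL that reproduces the current inputs via query parameters.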
url = "https://zeel-twro.hf.space"
params = {
    "model": model,
    "message": message,
    "instruction_part": instruction_part,
    "response_part": response_part,
    "chat_template_idx": chat_template_key,
}

import urllib.parse

url = url + "?" + urllib.parse.urlencode(params)
st.markdown(f"`{url}`")