File size: 5,253 Bytes
42c11d3 7145796 6cb37fc 769d8b8 fcbbb53 42c11d3 6cb37fc 42c11d3 189e1de 6cb37fc 769d8b8 42c11d3 189e1de 6cb37fc fcbbb53 7145796 68f1499 e685da5 68f1499 e685da5 8b720fd e685da5 6ac728f e685da5 6ac728f e685da5 71604ee fcbbb53 f429493 fcbbb53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import sys
import torch
import torch.utils._triton
# >>> Hack unsloth to work without GPU
def fake_is_available():
return True
def fake_device_capability(*args, **kwargs):
return (8, 0)
def fake_current_device():
return 0
def fake_device_count():
return 1
def has_triton():
return False
def get_fake_stream(*args, **kwargs):
return 0
sys.modules["torch"].cuda.is_available = fake_is_available
sys.modules["torch"].cuda.get_device_capability = fake_device_capability
sys.modules["torch"].cuda.current_device = fake_current_device
sys.modules["torch"].cuda.device_count = fake_device_count
sys.modules["torch.utils._triton"].has_triton = has_triton
sys.modules["torch._C"]._cuda_getCurrentRawStream = get_fake_stream
# <<<
import streamlit as st
from unsloth.chat_templates import get_chat_template, CHAT_TEMPLATES
from unsloth_zoo.dataset_utils import train_on_responses_only
from transformers import AutoProcessor
class DummyArgs:
    """Empty attribute bag standing in for a TrainingArguments object."""
class DummyDataset:
    """Single-example stand-in mimicking the tiny slice of the HF Dataset
    API (map / len / indexing) that train_on_responses_only touches."""

    def __init__(self, example):
        # Store the lone example inside a list so integer indexing works.
        self.example = [example]

    def map(self, function, *args, **kwargs):
        # Apply `function` to the single row and merge its output in place,
        # mirroring datasets.Dataset.map for a one-row dataset.
        updated = function(self.example[0])
        self.example[0].update(updated)
        return self

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return self.example[idx]
class DummyTrainer:
    """Bare object onto which trainer attributes are attached ad hoc."""
st.title('Train With Response Only Analyzer')
# Show the Unsloth badge centered in the middle of three equal columns.
_, middle, _ = st.columns([1, 1, 1])
middle.image("https://raw.githubusercontent.com/unslothai/unsloth/main/images/made%20with%20unsloth.png", width=200)
# Model name comes from the URL query param when present, else the default.
model = st.text_input("Enter HuggingFace model name", st.query_params.get("model", "Qwen/Qwen2-VL-7B-Instruct"))
processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)
# Multimodal processors wrap the text tokenizer; plain tokenizers are used directly.
text_tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
chat_template_predefined = st.query_params.get("chat_template_idx", None)
possible_templates = ["model_default"] + sorted(CHAT_TEMPLATES.keys())
# Preselect the template named in the URL param, but only when it is a known
# option — a stale or garbled param must not crash with ValueError on .index().
if chat_template_predefined in possible_templates:
    chat_template_idx = possible_templates.index(chat_template_predefined)
else:
    chat_template_idx = 0
chat_template_key = st.selectbox("Select chat template", possible_templates, index=chat_template_idx)
if chat_template_key == "model_default":
    chat_template = None
else:
    # selectbox only yields keys from possible_templates, so direct indexing
    # is safe; entry layout is (template_string, ...).
    chat_template = CHAT_TEMPLATES[chat_template_key][0]
if chat_template is None:
    # Fall back to whatever template the tokenizer itself ships with (may be None).
    chat_template = text_tokenizer.chat_template
if chat_template is None:
    # No template from the dropdown and none on the tokenizer — warn and
    # continue with raw messages rather than raising.
    st.warning("Chat template not found in the tokenizer. Not using any chat template.")
with st.expander("Click to see the chat template"):
    st.markdown("#### Chat Template (in Jinja2 format)")
    st.code(chat_template, language="jinja2")
import ast

# Default demo conversation in the "conversations" list-of-dicts format.
sample = {"conversations": [{'content': 'Do you like Unsloth?', 'role': 'user'}, {'content': 'Yes', 'role': 'assistant'}, {'content': 'Will you star them on GitHub?', 'role': 'user'}, {'content': 'Sure!', 'role': 'assistant'}]}
message_sample = sample.get("conversations", "")
message = st.text_area("Enter your message here", st.query_params.get("message", str(message_sample)))
try:
    # literal_eval safely parses the pasted list-of-dicts; unlike eval() it
    # cannot execute arbitrary code from the user-controlled text box.
    message = ast.literal_eval(message)
except (ValueError, SyntaxError):
    # Not a Python literal — keep the raw string and let the template handle it.
    pass
if chat_template is not None:
    converted_message = text_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False, chat_template=chat_template)
else:
    converted_message = message
st.markdown("#### Original Message")
st.code(converted_message, language="html")
instruction_part = st.text_input("Enter instruction Part", st.query_params.get("instruction_part", "<|im_start|>user"))
response_part = st.text_input("Enter response Part", st.query_params.get("response_part", "<|im_start|>assistant"))
# Assemble the minimal trainer shim that train_on_responses_only expects:
# a one-row dataset, a tokenizer, and args with dataset_kwargs.
trainer = DummyTrainer()
trainer.train_dataset = DummyDataset({"input_ids": [text_tokenizer.encode(converted_message)]})
trainer.tokenizer = text_tokenizer
trainer.args = DummyArgs()
trainer.args.dataset_kwargs = {"skip_prepare_dataset": False}
trainer = train_on_responses_only(trainer, instruction_part, response_part)
label_ids = trainer.train_dataset[0]["labels"][0]
# Render every masked label (-100) as the token id of "x" so the mask is visible.
mask_token_id = text_tokenizer.encode("x", add_special_tokens=False)[0]
masked_text = text_tokenizer.decode([tok if tok != -100 else mask_token_id for tok in label_ids])
st.markdown("#### Masked Prompt ('x' is the mask token)")
st.code(masked_text, language="html")
st.markdown("#### Your Unsloth code snippet")
# Ready-to-paste snippet reproducing the masking configuration chosen above.
snippet = f"""from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "{instruction_part}",
    response_part = "{response_part}",
)
"""
st.code(snippet, language="python")
import urllib.parse

st.markdown("#### You may share the following URL with others to show them the results")
base_url = "https://zeel-twro.hf.space"
params = {
    "model": model,
    "message": message,
    "instruction_part": instruction_part,
    "response_part": response_part,
    "chat_template_idx": chat_template_key,
}
# urlencode percent-escapes the values (the message may contain spaces,
# quotes, and special tokens) so the link survives copy/paste.
url = f"{base_url}?{urllib.parse.urlencode(params)}"
st.markdown(f"`{url}`")