TWRO / app.py
Zeel's picture
encode the URL correctly.
f429493
import sys
import torch
import torch.utils._triton
# >>> Hack unsloth to work without GPU
def fake_is_available() -> bool:
    """Always report that CUDA is available.

    Installed further down as ``torch.cuda.is_available``.
    """
    return True
def fake_device_capability(*args, **kwargs) -> tuple:
    """Return the fixed pair ``(8, 0)``, ignoring every argument.

    Installed further down as ``torch.cuda.get_device_capability``.
    """
    return (8, 0)
def fake_current_device() -> int:
    """Always report device index ``0``.

    Installed further down as ``torch.cuda.current_device``.
    """
    return 0
def fake_device_count() -> int:
    """Always report exactly one device.

    Installed further down as ``torch.cuda.device_count``.
    """
    return 1
def has_triton() -> bool:
    """Always report that Triton is not available.

    Installed further down as ``torch.utils._triton.has_triton``.
    """
    return False
def get_fake_stream(*args, **kwargs) -> int:
    """Return ``0`` as the stream value, ignoring every argument.

    Installed further down as ``torch._C._cuda_getCurrentRawStream``.
    """
    return 0
# Swap the real CUDA/Triton probes for the fakes defined above, directly on
# the already-imported torch modules.
# NOTE(review): presumably this has to happen before the unsloth imports that
# follow, so that unsloth sees the faked values — confirm before reordering.
_cuda_overrides = {
    "is_available": fake_is_available,
    "get_device_capability": fake_device_capability,
    "current_device": fake_current_device,
    "device_count": fake_device_count,
}
for _attr_name, _replacement in _cuda_overrides.items():
    setattr(sys.modules["torch"].cuda, _attr_name, _replacement)
setattr(sys.modules["torch.utils._triton"], "has_triton", has_triton)
setattr(sys.modules["torch._C"], "_cuda_getCurrentRawStream", get_fake_stream)
# <<<
import streamlit as st
from unsloth.chat_templates import get_chat_template, CHAT_TEMPLATES
from unsloth_zoo.dataset_utils import train_on_responses_only
from transformers import AutoProcessor
class DummyArgs:
    """Empty attribute holder used as the fake trainer's ``args``.

    The script attaches ``dataset_kwargs`` to an instance after creating it.
    """
class DummyDataset:
    """Single-example stand-in for a dataset.

    Supports only what this script needs: ``map``, ``len`` and indexing.
    """

    def __init__(self, example):
        # The dict is stored by reference, so map() also mutates the caller's
        # object.
        self.example = [example]

    def map(self, function, *args, **kwargs):
        """Merge ``function(example)`` into the stored example and return self.

        Any extra positional or keyword arguments are accepted and ignored.
        """
        record = self.example[0]
        record.update(function(record))
        return self

    def __len__(self):
        # There is always exactly one example.
        return 1

    def __getitem__(self, idx):
        return self.example[idx]
class DummyTrainer:
    """Empty attribute holder standing in for a trainer object.

    The script sets ``train_dataset``, ``tokenizer`` and ``args`` on an
    instance before handing it to ``train_on_responses_only``.
    """
# Page header: the title, then the "made with unsloth" badge placed in the
# middle one of three equally weighted columns.
st.title('Train With Response Only Analyzer')
_, col, _ = st.columns([1, 1, 1])
col.image(
    "https://raw.githubusercontent.com/unslothai/unsloth/main/images/made%20with%20unsloth.png",
    width=200,
)
# The model name comes from the text box, pre-filled from the ``model`` query
# parameter when the page was opened through a shared link.
default_model = st.query_params.get("model", "Qwen/Qwen2-VL-7B-Instruct")
model = st.text_input("Enter HuggingFace model name", default_model)
# NOTE(review): trust_remote_code=True is applied to a repository name typed
# by the visitor — confirm that this is acceptable for this deployment.
processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)
# When the loaded object wraps a tokenizer, use that; otherwise use the loaded
# object itself for the text operations that follow.
if hasattr(processor, "tokenizer"):
    text_tokenizer = processor.tokenizer
else:
    text_tokenizer = processor
# --- Chat template selection ---
# The template can be preselected through the ``chat_template_idx`` query
# parameter. Despite its name, the parameter carries a template *name* (it is
# looked up in ``possible_templates``), i.e. "model_default" or a key of
# CHAT_TEMPLATES.
chat_template_predefined = st.query_params.get("chat_template_idx", None)
possible_templates = ["model_default"] + sorted(CHAT_TEMPLATES.keys())
if chat_template_predefined in possible_templates:
    chat_template_idx = possible_templates.index(chat_template_predefined)
else:
    # An unknown name (stale or hand-edited URL) must not break the page:
    # list.index() would raise ValueError, so fall back to the first entry.
    if chat_template_predefined is not None:
        st.warning(
            f"Unknown chat template {chat_template_predefined!r} in the URL. "
            "Falling back to 'model_default'."
        )
    chat_template_idx = 0
chat_template_key = st.selectbox("Select chat template", possible_templates, index=chat_template_idx)

# Resolve the selection to the template itself. "model_default" defers to
# whatever the tokenizer ships with; any other choice takes the first element
# of the matching CHAT_TEMPLATES entry.
if chat_template_key == "model_default":
    chat_template = None
else:
    chat_template = CHAT_TEMPLATES[chat_template_key][0]
if chat_template is None:
    chat_template = text_tokenizer.chat_template
    if chat_template is None:
        # Best effort: carry on without a template instead of failing.
        st.warning("Chat template not found in the tokenizer. Not using any chat template.")

with st.expander("Click to see the chat template"):
    st.markdown("#### Chat Template (in Jinja2 format)")
    st.code(chat_template, language="jinja2")
import ast

# Default conversation shown in the text area when the URL carries no
# ``message`` parameter.
sample = {
    "conversations": [
        {'content': 'Do you like Unsloth?', 'role': 'user'},
        {'content': 'Yes', 'role': 'assistant'},
        {'content': 'Will you star them on GitHub?', 'role': 'user'},
        {'content': 'Sure!', 'role': 'assistant'},
    ]
}
message_sample = sample.get("conversations", "")
message = st.text_area("Enter your message here", st.query_params.get("message", str(message_sample)))

# SECURITY: this text is fully controlled by the visitor (text area, or the
# ``message`` query parameter of a shared link). It used to go through eval(),
# which would execute arbitrary Python in the app process.
# ast.literal_eval() still accepts the list-of-dicts literal produced by
# str(message_sample), but it evaluates literals only. Inputs that relied on
# eval() computing an expression are now treated as raw text.
try:
    message = ast.literal_eval(message)
except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
    # Not a Python literal: keep the raw text, as the original fallback did.
    if chat_template is not None:
        st.warning(
            "Could not parse the message as a Python literal (for example a list of "
            "{'role': ..., 'content': ...} dicts); passing it to the chat template as raw text."
        )

# Render the conversation with the selected template; without a template the
# message is shown and tokenized as-is.
if chat_template is not None:
    converted_message = text_tokenizer.apply_chat_template(
        message,
        tokenize=False,
        add_generation_prompt=False,
        chat_template=chat_template,
    )
else:
    converted_message = message
st.markdown("#### Original Message")
st.code(converted_message, language="html")
# Markers that delimit the instruction and response parts, pre-filled from the
# query string when present.
default_instruction_part = st.query_params.get("instruction_part", "<|im_start|>user")
instruction_part = st.text_input("Enter instruction Part", default_instruction_part)
default_response_part = st.query_params.get("response_part", "<|im_start|>assistant")
response_part = st.text_input("Enter response Part", default_response_part)

# Assemble the minimal fake trainer that train_on_responses_only() is given:
# a one-example dataset, the tokenizer, and args carrying ``dataset_kwargs``.
# The token ids are wrapped in a list and the labels are read back at [0]
# below (presumably because the mapping function works on batches — confirm
# against unsloth_zoo).
token_ids = text_tokenizer.encode(converted_message)
trainer = DummyTrainer()
trainer.train_dataset = DummyDataset({"input_ids": [token_ids]})
trainer.tokenizer = text_tokenizer
trainer.args = DummyArgs()
trainer.args.dataset_kwargs = {"skip_prepare_dataset": False}
trainer = train_on_responses_only(trainer, instruction_part, response_part)

# Every label equal to -100 is displayed as the token for "x"; all other
# labels are decoded unchanged.
ids = trainer.train_dataset[0]["labels"][0]
mask = text_tokenizer.encode("x", add_special_tokens = False)[0]
display_ids = [mask if label == -100 else label for label in ids]
masked_text = text_tokenizer.decode(display_ids)
st.markdown("#### Masked Prompt ('x' is the mask token)")
st.code(masked_text, language="html")
# Show a ready-to-paste snippet that calls train_on_responses_only() with the
# instruction/response markers currently entered in the text boxes.
st.markdown("#### Your Unsloth code snippet")
# NOTE(review): the markers are interpolated verbatim between double quotes, so
# a marker containing a double quote or a backslash would not yield an
# equivalent Python string literal — confirm whether escaping is wanted.
code = f"""from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
trainer,
instruction_part = "{instruction_part}",
response_part = "{response_part}",
)
"""
st.code(code, language="python")
import urllib.parse

st.markdown("#### You may share the following URL with others to show them the results")
# Every widget value goes into the query string under the key the page reads
# back on load. urlencode() percent-encodes each value and converts
# non-string values (``message`` may be a parsed list) with str().
params = dict(
    model=model,
    message=message,
    instruction_part=instruction_part,
    response_part=response_part,
    chat_template_idx=chat_template_key,
)
query_string = urllib.parse.urlencode(params)
url = "https://zeel-twro.hf.space" + "?" + query_string
st.markdown(f"`{url}`")