File size: 5,253 Bytes
42c11d3
7145796
6cb37fc
769d8b8
fcbbb53
42c11d3
 
6cb37fc
42c11d3
189e1de
 
6cb37fc
 
 
 
 
 
769d8b8
42c11d3
 
189e1de
6cb37fc
 
 
fcbbb53
7145796
68f1499
e685da5
 
68f1499
e685da5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b720fd
 
 
 
 
 
 
 
e685da5
 
 
 
 
 
 
 
 
6ac728f
 
 
e685da5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ac728f
 
 
 
e685da5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71604ee
 
 
 
 
 
 
 
 
 
fcbbb53
f429493
 
 
 
 
 
 
 
 
 
 
fcbbb53
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import sys
import torch
import torch.utils._triton

# >>> Hack unsloth to work without GPU
def fake_is_available():
    """Report that CUDA is usable, unconditionally.

    Installed below in place of ``torch.cuda.is_available`` so the GPU
    availability check passes on a CPU-only host.
    """
    cuda_present = True
    return cuda_present
def fake_device_capability(*args, **kwargs):
    """Return a fixed ``(8, 0)`` capability tuple, ignoring every argument.

    Installed below in place of ``torch.cuda.get_device_capability``; the
    positional and keyword arguments are accepted only so any call
    signature is tolerated.
    """
    major, minor = 8, 0
    return (major, minor)
def fake_current_device():
    """Return device index 0 unconditionally.

    Installed below in place of ``torch.cuda.current_device``.
    """
    device_index = 0
    return device_index
def fake_device_count():
    """Return a device count of 1 unconditionally.

    Installed below in place of ``torch.cuda.device_count``.
    """
    visible_devices = 1
    return visible_devices
def has_triton():
    """Report that Triton is not available, unconditionally.

    Installed below in place of ``torch.utils._triton.has_triton``.
    """
    triton_present = False
    return triton_present
def get_fake_stream(*args, **kwargs):
    """Return 0 as a placeholder stream value, ignoring every argument.

    Installed below in place of ``torch._C._cuda_getCurrentRawStream``.
    """
    placeholder_stream = 0
    return placeholder_stream

# Swap the real CUDA/Triton probes for the stubs defined above. The modules
# are already imported at the top of the file, so patching their attributes
# directly touches the same objects that live in sys.modules.
torch.cuda.is_available = fake_is_available
torch.cuda.get_device_capability = fake_device_capability
torch.cuda.current_device = fake_current_device
torch.cuda.device_count = fake_device_count
torch.utils._triton.has_triton = has_triton
torch._C._cuda_getCurrentRawStream = get_fake_stream
# <<<

import streamlit as st
from unsloth.chat_templates import get_chat_template, CHAT_TEMPLATES
from unsloth_zoo.dataset_utils import train_on_responses_only
from transformers import AutoProcessor
class DummyArgs:
    """Empty attribute holder used as the ``args`` object of the dummy trainer.

    It defines nothing itself; the script attaches ``dataset_kwargs`` to an
    instance after creating it.
    """

class DummyDataset:
    """One-row, in-memory stand-in for a dataset object.

    Wraps a single example dict and offers just ``map``, ``len()`` and
    indexing, which is all this script needs from a dataset.
    """

    def __init__(self, example):
        # Stored in a one-element list so indexing behaves like a sequence.
        self.example = [example]

    def map(self, function, *args, **kwargs):
        """Apply ``function`` to the single row and merge its result in place.

        Extra positional and keyword arguments are accepted and ignored.
        Returns ``self`` so the call can be chained or reassigned.
        """
        row = self.example[0]
        new_columns = function(row)
        row.update(new_columns)
        return self

    def __len__(self):
        # The wrapper always reports exactly one row.
        return 1

    def __getitem__(self, idx):
        return self.example[idx]

class DummyTrainer:
    """Empty attribute holder passed to ``train_on_responses_only`` as the trainer.

    The script attaches ``train_dataset``, ``tokenizer`` and ``args`` to an
    instance after creating it.
    """

# Page header: the title, then the Unsloth badge in the middle of three equal columns.
st.title('Train With Response Only Analyzer')
_, col, _ = st.columns([1, 1, 1])
col.image(
    "https://raw.githubusercontent.com/unslothai/unsloth/main/images/made%20with%20unsloth.png",
    width=200,
)

# The text box is pre-filled from the "model" query parameter when present.
default_model = st.query_params.get("model", "Qwen/Qwen2-VL-7B-Instruct")
model = st.text_input("Enter HuggingFace model name", default_model)

# NOTE(review): the repo name is user-supplied and trust_remote_code=True is
# passed to the loader — confirm this is acceptable for the deployment.
processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)
# Use the processor's nested tokenizer when it has one, otherwise the processor itself.
text_tokenizer = getattr(processor, "tokenizer", processor)
# Choices offered in the dropdown: the model's own template first, then every
# template name known to Unsloth, alphabetically.
possible_templates = ["model_default"] + sorted(CHAT_TEMPLATES.keys())

# A shared URL can preselect a template through the "chat_template_idx" query
# parameter; despite its name it carries the template *name*, not an index.
chat_template_predefined = st.query_params.get("chat_template_idx", None)
if chat_template_predefined is None:
    chat_template_idx = 0
elif chat_template_predefined in possible_templates:
    chat_template_idx = possible_templates.index(chat_template_predefined)
else:
    # An unknown name (stale or hand-edited URL) used to raise ValueError from
    # list.index() and take the whole page down; fall back to the default.
    st.warning(
        f"Unknown chat template '{chat_template_predefined}' in the URL. "
        "Falling back to 'model_default'."
    )
    chat_template_idx = 0

chat_template_key = st.selectbox("Select chat template", possible_templates, index=chat_template_idx)

# Resolve the template text: each CHAT_TEMPLATES entry keeps it at position 0.
if chat_template_key == "model_default":
    chat_template = None
else:
    chat_template = CHAT_TEMPLATES[chat_template_key][0]

# Fall back to whatever template ships with the tokenizer.
if chat_template is None:
    chat_template = text_tokenizer.chat_template
if chat_template is None:
    st.warning("Chat template not found in the tokenizer. Not using any chat template.")

with st.expander("Click to see the chat template"):
    st.markdown("#### Chat Template (in Jinja2 format)")
    st.code(chat_template, language="jinja2")

import ast

# Default conversation shown in the text area when the URL carries no "message".
sample = {"conversations": [{'content': 'Do you like Unsloth?', 'role': 'user'}, {'content': 'Yes', 'role': 'assistant'}, {'content': 'Will you star them on GitHub?', 'role': 'user'}, {'content': 'Sure!', 'role': 'assistant'}]}

message_sample = sample.get("conversations", "")
message = st.text_area("Enter your message here", st.query_params.get("message", str(message_sample)))

# The text area holds the Python literal form of the conversation (a list of
# dicts). SECURITY: this text comes straight from the user or the URL, so it
# must not go through eval(), which would execute arbitrary code on the
# server. ast.literal_eval only accepts plain literals.
try:
    message = ast.literal_eval(message)
except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
    # Not a Python literal: keep the raw text unchanged, as before.
    pass

# Render the conversation through the chosen template; without a template the
# message is shown and tokenized exactly as entered.
if chat_template is not None:
    converted_message = text_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False, chat_template=chat_template)
else:
    converted_message = message

st.markdown("#### Original Message")
st.code(converted_message, language="html")

# Marker strings for the instruction and response turns; each text box is
# pre-filled from its query parameter when present.
default_instruction = st.query_params.get("instruction_part", "<|im_start|>user")
default_response = st.query_params.get("response_part", "<|im_start|>assistant")
instruction_part = st.text_input("Enter instruction Part", default_instruction)
response_part = st.text_input("Enter response Part", default_response)

# Build the minimal trainer-shaped object handed to train_on_responses_only:
# a one-row dataset with the tokenized conversation, the tokenizer, and an
# args stub carrying dataset_kwargs.
token_ids = text_tokenizer.encode(converted_message)
training_args = DummyArgs()
training_args.dataset_kwargs = {"skip_prepare_dataset": False}

trainer = DummyTrainer()
trainer.train_dataset = DummyDataset({"input_ids": [token_ids]})
trainer.tokenizer = text_tokenizer
trainer.args = training_args

trainer = train_on_responses_only(trainer, instruction_part, response_part)

# Swap every -100 label for the token id of "x" so those positions show up as
# "x" in the decoded text (presumably the masked-out positions — confirm
# against train_on_responses_only).
labels = trainer.train_dataset[0]["labels"][0]
placeholder_id = text_tokenizer.encode("x", add_special_tokens=False)[0]
display_ids = [placeholder_id if label == -100 else label for label in labels]
masked_text = text_tokenizer.decode(display_ids)

st.markdown("#### Masked Prompt ('x' is the mask token)")
st.code(masked_text, language="html")

st.markdown("#### Your Unsloth code snippet")
# Show a copy-pasteable call using the markers entered above. The markers are
# rendered with !r (repr) rather than pasted between hard-coded quotes, so a
# quote, backslash or newline inside a marker still produces a valid Python
# string literal that evaluates to exactly the marker that was used here.
code = f"""from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = {instruction_part!r},
    response_part = {response_part!r},
)
"""
st.code(code, language="python")

st.markdown("#### You may share the following URL with others to show them the results")
import urllib.parse

# Every input of the page travels as a query parameter so that opening the URL
# pre-fills the same values; urlencode takes care of percent-escaping.
params = dict(
    model=model,
    message=message,
    instruction_part=instruction_part,
    response_part=response_part,
    chat_template_idx=chat_template_key,
)
query_string = urllib.parse.urlencode(params)
url = f"https://zeel-twro.hf.space?{query_string}"
st.markdown("`" + url + "`")