from datasets import load_dataset, concatenate_datasets
from transformers import TrainingArguments, TextStreamer
from trl import SFTTrainer
from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel, is_bfloat16_supported
# ###############################################################################
# # 1. Load/Initialize Model and Tokenizer
# ###############################################################################
# max_seq_length = 2048
# model_name = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
# model, tokenizer = FastLanguageModel.from_pretrained(
# model_name=model_name,
# max_seq_length=max_seq_length,
# load_in_4bit=True,
# dtype=None,
# )
# model = FastLanguageModel.get_peft_model(
# model,
# r=16,
# lora_alpha=16,
# lora_dropout=0,
# target_modules=[
# "q_proj", "k_proj", "v_proj", "up_proj", "down_proj",
# "o_proj", "gate_proj"
# ],
# use_rslora=True,
# use_gradient_checkpointing="unsloth"
# )
# # Prepare the tokenizer for "chatml" format
# tokenizer = get_chat_template(
# tokenizer,
# mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt"},
# chat_template="chatml",
# )
# ###############################################################################
# # 2. Dataset Loading and Caching
# ###############################################################################
# # Custom function to apply the chat template to ShareGPT-style rows:
# def apply_template(examples):
# messages_batch = examples["conversations"]
# texts = []
# for message in messages_batch:
# text = tokenizer.apply_chat_template(
# message,
# tokenize=False,
# add_generation_prompt=False
# )
# texts.append(text)
# return {"text": texts}
# def apply_template2(examples):
# import json
# conversation_batch = examples["conversation"]
# tools_batch = examples["tools"]
# texts = []
# for i, conversation_json_str in enumerate(conversation_batch):
# # 1) Load conversation & tools:
# thread = json.loads(conversation_json_str)
# tools_data = json.loads(tools_batch[i])
# # 2) Convert "arguments" to "parameters"
# for tool in tools_data:
# if "arguments" in tool:
# tool["parameters"] = tool["arguments"]
# # 3) Create system prompt
# system_prompt = {
# "from": "system",
# "value": (
# "You are a function calling AI model. You are provided with "
# "function signatures within <tools> </tools> XML tags. Don't make "
# "assumptions about what values to plug into functions.\n"
# f"<tools>{json.dumps(tools_data)}</tools>"
# )
# }
# # 4) Build new conversation
# clean_thread = [system_prompt]
# for msg in thread:
# # Rename the "tool call" role to "gpt", the assistant role name used by the chatml mapping above
# if msg["role"] == "tool call":
# msg["role"] = "gpt"
# # Wrap the message content in <tool_call> ... </tool_call> tags
# item = json.dumps({"type": "function", "function": msg["content"]})
# clean_thread.append({
# "from": msg["role"],
# "value": f"<tool_call>{item}</tool_call>"
# })
# # 5) Pass the list (NOT the JSON string) to apply_chat_template
# text = tokenizer.apply_chat_template(
# clean_thread,
# tokenize=False,
# add_generation_prompt=False
# )
# texts.append(text)
# return {"text": texts}
# tool_intro = "You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions."
# # If you want a local cache file, specify cache_file_name
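# # A minimal sketch (the cache path is illustrative) of caching a mapped split to disk:
# #   dataset_1.map(apply_template, batched=True, cache_file_name="cache/tool_calls_sharegpt.arrow")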
# dataset_1 = load_dataset(
# "interstellarninja/tool-calls-sharegpt",
# split="train",
# )
# # Load second dataset
# dataset_2 = load_dataset(
# "interstellarninja/tool-calls-multiturn",
# split="train",
# )
# dataset_3 = load_dataset(
# "BitAgent/tool_calling",
# split="train",
# )
# dataset_1 = dataset_1.map(apply_template, batched=True)
# dataset_2 = dataset_2.map(apply_template, batched=True)
# dataset_3 = dataset_3.map(apply_template2, batched=True)
# # Concatenate the three datasets
# dataset = concatenate_datasets([dataset_1, dataset_2, dataset_3])
# ###############################################################################
# # 3. SFTTrainer and Training Arguments (with checkpointing)
# ###############################################################################
# training_args = TrainingArguments(
# learning_rate=3e-4,
# lr_scheduler_type="linear",
# per_device_train_batch_size=8,
# gradient_accumulation_steps=2,
# num_train_epochs=1,
# fp16=not is_bfloat16_supported(),
# bf16=is_bfloat16_supported(),
# logging_steps=1,
# optim="adamw_8bit",
# weight_decay=0.01,
# warmup_steps=10,
# output_dir="drive/MyDrive/Ribo/model-checkpoints",
# seed=0,
# report_to="none",
# )
# trainer = SFTTrainer(
# model=model,
# tokenizer=tokenizer,
# train_dataset=dataset,
# dataset_text_field="text",
# max_seq_length=max_seq_length,
# dataset_num_proc=2,
# packing=True,
# args=training_args,
# )
# ###############################################################################
# # 4. Train and Save Checkpoints
# ###############################################################################
# trainer.train()
# # Checkpoints are written under `output_dir` (here drive/MyDrive/Ribo/model-checkpoints/checkpoint-*)
# # every `save_steps` steps. Training can be resumed from one of them by setting
# # `resume_from_checkpoint`, as sketched below.
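# # A minimal sketch of resuming (the specific checkpoint folder name is illustrative):
# #   trainer.train(resume_from_checkpoint=True)  # resume from the latest checkpoint in output_dir
# #   trainer.train(resume_from_checkpoint="drive/MyDrive/Ribo/model-checkpoints/checkpoint-500")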
# ###############################################################################
# # 5. Convert to Inference Model
# ###############################################################################
# model = FastLanguageModel.for_inference(model)
# ###############################################################################
# # 7. Save & Push Final Merged Model
# ###############################################################################
# # Save model merged (16-bit) locally
# model.save_pretrained_merged(
# "drive/MyDrive/Ribo/model",
# tokenizer,
# save_method="merged_16bit"
# )
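# # To also push the merged 16-bit model to the Hub (repo id and token are placeholders),
# # Unsloth's companion method can be used:
# #   model.push_to_hub_merged("your-username/your-model", tokenizer,
# #                            save_method="merged_16bit", token="hf_...")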
model, tokenizer = FastLanguageModel.from_pretrained("./")
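# Switch to Unsloth's faster inference mode (the same call as in the commented step 5 above)
model = FastLanguageModel.for_inference(model)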
###############################################################################
# 6. Example Inference with TextStreamer
###############################################################################
messages = [
    {
        "from": "system",
        "value": """
Available tools:
[
    {
        "type": "function",
        "function": {
            "name": "get_current_date",
            "description": "Returns the current date in the format specified",
            "parameters": {
                "type": "object",
                "required": ["format"],
                "properties": {
                    "format": {
                        "type": "string",
                        "description": "will format the date in the format specified MM/DD/YY or similar"
                    }
                }
            }
        }
    }
]
"""
    },
    {"from": "human", "value": "What is the current date?"},
]
formatted_text = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,  # return input_ids and attention_mask together instead of a bare tensor
    return_tensors="pt",
).to(model.device)
# If your GPU has limited memory, you might need smaller max_new_tokens
# or streaming logic
text_streamer = TextStreamer(tokenizer)
output = model.generate(
    input_ids=formatted_text["input_ids"],
    attention_mask=formatted_text["attention_mask"],
    streamer=text_streamer,
    max_new_tokens=4096,
    use_cache=True,
)
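# The streamer already prints tokens as they are generated; as an optional follow-up,
# decode the full output (prompt plus response) with the standard tokenizer API:
generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
print(generated_text)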