# Import unsloth before transformers/trl so its patches are applied first.
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template
from datasets import load_dataset, concatenate_datasets
from transformers import TrainingArguments, TextStreamer
from trl import SFTTrainer
# ###############################################################################
# # 1. Load/Initialize Model and Tokenizer
# ###############################################################################
# max_seq_length = 2048
# model_name = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name=model_name,
#     max_seq_length=max_seq_length,
#     load_in_4bit=True,
#     dtype=None,
# )
# model = FastLanguageModel.get_peft_model(
#     model,
#     r=16,
#     lora_alpha=16,
#     lora_dropout=0,
#     target_modules=[
#         "q_proj", "k_proj", "v_proj", "up_proj", "down_proj",
#         "o_proj", "gate_proj"
#     ],
#     use_rslora=True,
#     use_gradient_checkpointing="unsloth"
# )
# # Prepare the tokenizer for the "chatml" format
# tokenizer = get_chat_template(
#     tokenizer,
#     mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt"},
#     chat_template="chatml",
# )
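# # Optional sanity check (illustrative, not part of the original run): render
# # one toy ShareGPT-style exchange to confirm the "from"/"value" mapping and
# # the chatml template are wired up as expected.
# sample_exchange = [
#     {"from": "human", "value": "Hi"},
#     {"from": "gpt", "value": "Hello!"},
# ]
# print(tokenizer.apply_chat_template(sample_exchange, tokenize=False))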
# ###############################################################################
# # 2. Dataset Loading and Caching
# ###############################################################################
# # Apply the chat template to ShareGPT-style records ("conversations" column):
# def apply_template(examples):
#     messages_batch = examples["conversations"]
#     texts = []
#     for message in messages_batch:
#         text = tokenizer.apply_chat_template(
#             message,
#             tokenize=False,
#             add_generation_prompt=False
#         )
#         texts.append(text)
#     return {"text": texts}
# def apply_template2(examples):
#     import json
#     conversation_batch = examples["conversation"]
#     tools_batch = examples["tools"]
#     texts = []
#     for i, conversation_json_str in enumerate(conversation_batch):
#         # 1) Load conversation & tools:
#         thread = json.loads(conversation_json_str)
#         tools_data = json.loads(tools_batch[i])
#         # 2) Convert "arguments" to "parameters"
#         for tool in tools_data:
#             if "arguments" in tool:
#                 tool["parameters"] = tool["arguments"]
#         # 3) Create the system prompt
#         system_prompt = {
#             "from": "system",
#             "value": (
#                 "You are a function calling AI model. You are provided with "
#                 "function signatures within <tools> </tools> XML tags. Don't make "
#                 "assumptions about what values to plug into functions.\n"
#                 f"<tools>{json.dumps(tools_data)}</tools>"
#             )
#         }
#         # 4) Build the new conversation: tool-call turns are attributed to the
#         #    assistant ("gpt") role and wrapped in <tool_call> ... </tool_call>;
#         #    all other turns pass through with their original role and content.
#         clean_thread = [system_prompt]
#         for msg in thread:
#             if msg["role"] == "tool call":
#                 item = json.dumps({"type": "function", "function": msg["content"]})
#                 clean_thread.append({
#                     "from": "gpt",
#                     "value": f"<tool_call>{item}</tool_call>"
#                 })
#             else:
#                 clean_thread.append({
#                     "from": msg["role"],
#                     "value": msg["content"]
#                 })
#         # 5) Pass the LIST (not the JSON string) to apply_chat_template
#         text = tokenizer.apply_chat_template(
#             clean_thread,
#             tokenize=False,
#             add_generation_prompt=False
#         )
#         texts.append(text)
#     return {"text": texts}
# tool_intro = "You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions."  # (defined here but not referenced below)
# # If you want a local cache file, pass cache_file_name to .map()
# dataset_1 = load_dataset(
#     "interstellarninja/tool-calls-sharegpt",
#     split="train",
# )
# # Load second dataset
# dataset_2 = load_dataset(
#     "interstellarninja/tool-calls-multiturn",
#     split="train",
# )
# dataset_3 = load_dataset(
#     "BitAgent/tool_calling",
#     split="train",
# )
# dataset_1 = dataset_1.map(apply_template, batched=True)
# dataset_2 = dataset_2.map(apply_template, batched=True)
# dataset_3 = dataset_3.map(apply_template2, batched=True)
# # Concatenate all three datasets
# dataset = concatenate_datasets([dataset_1, dataset_2, dataset_3])
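# # Optional sanity check before training: confirm the merged dataset has a
# # populated "text" column and eyeball one formatted example.
# print(len(dataset))
# print(dataset[0]["text"][:500])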
# ###############################################################################
# # 3. SFTTrainer and Training Arguments (with checkpointing)
# ###############################################################################
# training_args = TrainingArguments(
#     learning_rate=3e-4,
#     lr_scheduler_type="linear",
#     per_device_train_batch_size=8,
#     gradient_accumulation_steps=2,
#     num_train_epochs=1,
#     fp16=not is_bfloat16_supported(),
#     bf16=is_bfloat16_supported(),
#     logging_steps=1,
#     optim="adamw_8bit",
#     weight_decay=0.01,
#     warmup_steps=10,
#     output_dir="drive/MyDrive/Ribo/model-checkpoints",
#     seed=0,
#     report_to="none",
# )
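# # Note: the effective train batch size is
# # per_device_train_batch_size * gradient_accumulation_steps = 8 * 2 = 16.
# # No save_strategy/save_steps is set here, so checkpoint frequency falls back
# # to the Trainer defaults; pass save_steps=... above to control how often
# # checkpoints land in output_dir.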
# trainer = SFTTrainer(
#     model=model,
#     tokenizer=tokenizer,
#     train_dataset=dataset,
#     dataset_text_field="text",
#     max_seq_length=max_seq_length,
#     dataset_num_proc=2,
#     packing=True,
#     args=training_args,
# )
# ###############################################################################
# # 4. Train and Save Checkpoints
# ###############################################################################
# trainer.train()
# # A checkpoint is saved every `save_steps` steps under `{output_dir}/checkpoint-*`
# # (here: drive/MyDrive/Ribo/model-checkpoints). You can resume training from
# # there by passing `resume_from_checkpoint`.
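# # To resume an interrupted run from the latest checkpoint in output_dir
# # (standard Trainer API; a specific checkpoint path also works):
# # trainer.train(resume_from_checkpoint=True)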
# ###############################################################################
# # 5. Convert to Inference Model
# ###############################################################################
# model = FastLanguageModel.for_inference(model)
# ###############################################################################
# # 6. Save & Push Final Merged Model
# ###############################################################################
# # Save the merged (16-bit) model locally
# model.save_pretrained_merged(
#     "drive/MyDrive/Ribo/model",
#     tokenizer,
#     save_method="merged_16bit"
# )
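# # The header also mentions pushing; unsloth provides push_to_hub_merged for
# # that. The repo name and token below are placeholders, not values from this
# # run:
# # model.push_to_hub_merged(
# #     "your-username/your-model-name",
# #     tokenizer,
# #     save_method="merged_16bit",
# #     token="hf_...",
# # )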
# Reload the merged model for inference (here from the current working
# directory; point this at wherever the merged model was saved).
model, tokenizer = FastLanguageModel.from_pretrained("./")
FastLanguageModel.for_inference(model)  # enable unsloth's fast inference path
###############################################################################
# 7. Example Inference with TextStreamer
###############################################################################
messages = [
    {
        "from": "system",
        "value": """
Available tools:
[
    {
        "type": "function",
        "function": {
            "name": "get_current_date",
            "description": "Returns the current date in the format specified",
            "parameters": {
                "type": "object",
                "required": ["format"],
                "properties": {
                    "format": {
                        "type": "string",
                        "description": "will format the date in the format specified MM/DD/YY or similar"
                    }
                }
            }
        }
    }
]
"""
    },
    {"from": "human", "value": "What is the current date?"},
]
formatted_text = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,  # needed so input_ids / attention_mask can be indexed below
    return_tensors="pt",
).to(model.device)
# If your GPU has limited memory, you might need a smaller max_new_tokens
# or streaming logic
text_streamer = TextStreamer(tokenizer)
output = model.generate(
    input_ids=formatted_text["input_ids"],
    attention_mask=formatted_text["attention_mask"],
    streamer=text_streamer,
    max_new_tokens=4096,
    use_cache=True
)
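# Optional: decode only the newly generated tokens and extract any <tool_call>
# payload the model emitted. This is a regex-based sketch that assumes the
# model follows the <tool_call>...</tool_call> convention used in training.
import json
import re

generated = tokenizer.decode(
    output[0][formatted_text["input_ids"].shape[-1]:],
    skip_special_tokens=True,
)
match = re.search(r"<tool_call>(.*?)</tool_call>", generated, re.DOTALL)
if match:
    print("Parsed tool call:", json.loads(match.group(1)))
else:
    print("No tool call found in:", generated)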