# Hugging Face Space source listing (Space status at capture time: Build error)
| import os | |
| import torch | |
| from unsloth import FastLanguageModel, is_bfloat16_supported | |
| from trl import SFTTrainer | |
| from transformers import TrainingArguments | |
| from datasets import load_dataset | |
| import gradio as gr | |
| import json | |
| from huggingface_hub import HfApi | |
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
max_seq_length = 4096  # used by the model loader here and by the trainer later
dtype = None  # forwarded unchanged to FastLanguageModel.from_pretrained
load_in_4bit = True

# Both values come from the environment of the process running this script.
hf_token = os.getenv("HF_TOKEN")
current_num = os.getenv("NUM")

# Fail fast: NUM is converted with int() only after training has finished, so
# a missing or non-numeric value would otherwise surface after the whole run.
try:
    int(current_num)
except (TypeError, ValueError) as err:
    raise RuntimeError(
        f"Environment variable NUM must be an integer, got {current_num!r}"
    ) from err

print(f"stage {current_num}")

api = HfApi(token=hf_token)
models = "dad1909/cybersentinal-2.0"

# ---------------------------------------------------------------------------
# Base model and tokenizer
# ---------------------------------------------------------------------------
print("Starting model and tokenizer loading...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=models,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    token=hf_token,
)
print("Model and tokenizer loaded successfully.")

# ---------------------------------------------------------------------------
# LoRA adapter configuration
# ---------------------------------------------------------------------------
print("Configuring PEFT model...")
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
print("PEFT model configured.")
# Prompt templates keyed by prompt type ("learning_from", "definition",
# "code_vulnerability"). Every template has exactly two positional ``{}``
# slots, filled in order with the dataset instruction and its expected output.
# NOTE(review): the captured source shows no blank lines between template
# sections; confirm against the deployed file before relying on exact layout.
alpaca_prompt = {
    "learning_from": (
        "Below is a CVE definition.\n"
        "### CVE definition:\n"
        "{}\n"
        "### detail CVE:\n"
        "{}"
    ),
    "definition": (
        "Below is a definition about software vulnerability. Explain it.\n"
        "### Definition:\n"
        "{}\n"
        "### Explanation:\n"
        "{}"
    ),
    "code_vulnerability": (
        "Below is a code snippet. Identify the line of code that is vulnerable"
        " and describe the type of software vulnerability.\n"
        "### Code Snippet:\n"
        "{}\n"
        "### Vulnerability solution:\n"
        "{}"
    ),
}
# End-of-sequence token of the loaded tokenizer; formatting_prompts_func
# appends it to every training text it builds.
EOS_TOKEN = tokenizer.eos_token
def detect_prompt_type(instruction: str) -> str:
    """Classify a dataset instruction by its leading phrase.

    The match is case-sensitive and anchored at the start of the string.

    Args:
        instruction: Instruction text of one dataset row.

    Returns:
        ``"code_vulnerability"``, ``"learning_from"`` or ``"definition"`` when
        the instruction starts with the corresponding phrase, otherwise
        ``"unknown"``.
    """
    # Order matters: the code-vulnerability phrase itself begins with
    # "what is", so it has to be tested before the generic definition prefix.
    prefix_to_type = (
        ("what is code vulnerable of this code:", "code_vulnerability"),
        ("Learning from", "learning_from"),
        ("what is", "definition"),
    )
    for prefix, prompt_type in prefix_to_type:
        if instruction.startswith(prefix):
            return prompt_type
    return "unknown"
def formatting_prompts_func(examples):
    """Build the training text for a batch of dataset rows.

    Args:
        examples: Batch mapping with parallel ``"instruction"`` and
            ``"output"`` sequences (the function is mapped over the dataset
            with ``batched=True``).

    Returns:
        ``{"text": [...]}`` with one string per row: the row rendered through
        the template for its detected prompt type, or
        ``instruction + "\\n\\n" + output`` when no template matches, each
        followed by ``EOS_TOKEN``.
    """
    texts = []
    for instruction, output in zip(examples["instruction"], examples["output"]):
        # "unknown" has no entry in alpaca_prompt, so .get() yields None for it.
        template = alpaca_prompt.get(detect_prompt_type(instruction))
        if template is None:
            body = instruction + "\n\n" + output
        else:
            body = template.format(instruction, output)
        texts.append(body + EOS_TOKEN)
    return {"text": texts}
# ---------------------------------------------------------------------------
# Training data
# ---------------------------------------------------------------------------
print("Loading dataset...")
dataset = load_dataset("admincybers2/DSV", split="train")
print("Dataset loaded successfully.")

# batched=True hands formatting_prompts_func a dict of column lists and
# expects a dict of lists back; the result is stored in a "text" column.
print("Applying formatting function to the dataset...")
dataset = dataset.map(formatting_prompts_func, batched=True)
print("Formatting function applied.")
# ---------------------------------------------------------------------------
# Trainer setup and training run
# ---------------------------------------------------------------------------
print("Initializing trainer...")

# Query hardware support once so fp16 and bf16 are derived from the same
# answer and can never both be enabled (or both disabled).
use_bf16 = is_bfloat16_supported()

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",  # column produced by formatting_prompts_func
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        # 5 sequences x 5 accumulation steps = 25 sequences per optimizer
        # step on each device.
        per_device_train_batch_size=5,
        gradient_accumulation_steps=5,
        learning_rate=2e-4,
        fp16=not use_bf16,
        bf16=use_bf16,
        warmup_steps=5,
        logging_steps=10,
        max_steps=200,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
    ),
)
print("Trainer initialized.")

print("Starting training...")
trainer_stats = trainer.train()
print("Training completed.")
# ---------------------------------------------------------------------------
# Persist the merged model, publish it, then advance the stage counter
# ---------------------------------------------------------------------------
up = "sentinal-2"

print("Saving the trained model...")
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
print("Model saved successfully.")

print("Pushing the model to the hub...")
model.push_to_hub_merged(
    up,
    tokenizer,
    save_method="merged_16bit",
    token=hf_token,
)
print("Model pushed to hub successfully.")

# Parse the counter only after the model has been saved and pushed, so a
# missing or malformed NUM can no longer discard the freshly trained weights.
try:
    num = int(current_num) + 1
except (TypeError, ValueError) as err:
    raise RuntimeError(
        f"Model was pushed, but NUM={current_num!r} is not an integer; "
        "the stage counter was not advanced"
    ) from err

# add_space_variable is documented as "adds or updates", so one call replaces
# the previous delete-then-add pair, which left the Space without NUM whenever
# the second call failed.
controller_space = "admincybers2/CyberController"
api.add_space_variable(repo_id=controller_space, key="NUM", value=str(num))