Update app.py
Browse files
app.py
CHANGED
|
@@ -33,9 +33,14 @@ model, tokenizer = FastLanguageModel.from_pretrained(
|
|
| 33 |
)
|
| 34 |
print("Model and tokenizer loaded successfully.")
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
print("Configuring PEFT model...")
|
| 37 |
model = FastLanguageModel.get_peft_model(
|
| 38 |
-
model,
|
| 39 |
r=16,
|
| 40 |
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
| 41 |
lora_alpha=16,
|
|
@@ -51,24 +56,18 @@ print("PEFT model configured.")
|
|
| 51 |
# Updated alpaca_prompt for different types
|
| 52 |
alpaca_prompt = {
|
| 53 |
"learning_from": """Below is a CVE definition.
|
| 54 |
-
|
| 55 |
### CVE definition:
|
| 56 |
{}
|
| 57 |
-
|
| 58 |
### detail CVE:
|
| 59 |
{}""",
|
| 60 |
"definition": """Below is a definition about software vulnerability. Explain it.
|
| 61 |
-
|
| 62 |
### Definition:
|
| 63 |
{}
|
| 64 |
-
|
| 65 |
### Explanation:
|
| 66 |
{}""",
|
| 67 |
"code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
|
| 68 |
-
|
| 69 |
### Code Snippet:
|
| 70 |
{}
|
| 71 |
-
|
| 72 |
### Vulnerability solution:
|
| 73 |
{}"""
|
| 74 |
}
|
|
@@ -111,7 +110,7 @@ print("Formatting function applied.")
|
|
| 111 |
|
| 112 |
print("Initializing trainer...")
|
| 113 |
trainer = SFTTrainer(
|
| 114 |
-
model=model,
|
| 115 |
tokenizer=tokenizer,
|
| 116 |
train_dataset=dataset,
|
| 117 |
dataset_text_field="text",
|
|
@@ -145,11 +144,16 @@ num += 1
|
|
| 145 |
uploads_models = f"cybersentinal-3.0"
|
| 146 |
|
| 147 |
print("Saving the trained model...")
|
| 148 |
-
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
|
| 149 |
print("Model saved successfully.")
|
| 150 |
|
| 151 |
print("Pushing the model to the hub...")
|
| 152 |
-
model.push_to_hub_merged(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
uploads_models,
|
| 154 |
tokenizer,
|
| 155 |
save_method="merged_16bit",
|
|
|
|
| 33 |
)
|
| 34 |
print("Model and tokenizer loaded successfully.")
|
| 35 |
|
| 36 |
+
# Wrap the model in DataParallel to use all GPUs
|
| 37 |
+
if torch.cuda.device_count() > 1:
|
| 38 |
+
print(f"Using {torch.cuda.device_count()} GPUs!")
|
| 39 |
+
model = torch.nn.DataParallel(model)
|
| 40 |
+
|
| 41 |
print("Configuring PEFT model...")
|
| 42 |
model = FastLanguageModel.get_peft_model(
|
| 43 |
+
model.module if isinstance(model, torch.nn.DataParallel) else model,
|
| 44 |
r=16,
|
| 45 |
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
| 46 |
lora_alpha=16,
|
|
|
|
| 56 |
# Updated alpaca_prompt for different types
|
| 57 |
alpaca_prompt = {
|
| 58 |
"learning_from": """Below is a CVE definition.
|
|
|
|
| 59 |
### CVE definition:
|
| 60 |
{}
|
|
|
|
| 61 |
### detail CVE:
|
| 62 |
{}""",
|
| 63 |
"definition": """Below is a definition about software vulnerability. Explain it.
|
|
|
|
| 64 |
### Definition:
|
| 65 |
{}
|
|
|
|
| 66 |
### Explanation:
|
| 67 |
{}""",
|
| 68 |
"code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
|
|
|
|
| 69 |
### Code Snippet:
|
| 70 |
{}
|
|
|
|
| 71 |
### Vulnerability solution:
|
| 72 |
{}"""
|
| 73 |
}
|
|
|
|
| 110 |
|
| 111 |
print("Initializing trainer...")
|
| 112 |
trainer = SFTTrainer(
|
| 113 |
+
model=model.module if isinstance(model, torch.nn.DataParallel) else model,
|
| 114 |
tokenizer=tokenizer,
|
| 115 |
train_dataset=dataset,
|
| 116 |
dataset_text_field="text",
|
|
|
|
| 144 |
uploads_models = f"cybersentinal-3.0"
|
| 145 |
|
| 146 |
print("Saving the trained model...")
|
| 147 |
+
model.module.save_pretrained_merged("model", tokenizer, save_method="merged_16bit") if isinstance(model, torch.nn.DataParallel) else model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
|
| 148 |
print("Model saved successfully.")
|
| 149 |
|
| 150 |
print("Pushing the model to the hub...")
|
| 151 |
+
model.module.push_to_hub_merged(
|
| 152 |
+
uploads_models,
|
| 153 |
+
tokenizer,
|
| 154 |
+
save_method="merged_16bit",
|
| 155 |
+
token=hf_token
|
| 156 |
+
) if isinstance(model, torch.nn.DataParallel) else model.push_to_hub_merged(
|
| 157 |
uploads_models,
|
| 158 |
tokenizer,
|
| 159 |
save_method="merged_16bit",
|