# Smart Contract Security Analyzer — Gradio demo app.
# Change history: conditional BitsAndBytesConfig import with fallback (commit ea8bc10).
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os
# ---------------------------------------------------------------------------
# Model configuration and one-time loading (runs at import time)
# ---------------------------------------------------------------------------
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
ADAPTER_MODEL = "mistral-hackaton-2026/Mistral_SmartContract_Security"

print("Loading model... This may take a few minutes on first run.")

# Tokenizer comes from the base model; reuse EOS as PAD when none is defined.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

# Decide the compute device once, up front.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if torch.cuda.is_available():
    # GPU path: try 4-bit NF4 quantization first; if bitsandbytes is missing
    # or the quantized load fails for any reason, fall back to plain fp16.
    try:
        print("Loading with 4-bit quantization...")
        from transformers import BitsAndBytesConfig

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"Quantization failed ({e}), loading in float16...")
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
else:
    # CPU path: full fp32 precision; low_cpu_mem_usage keeps peak RAM down.
    print("Loading on CPU (slower but works without GPU)...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        torch_dtype=torch.float32,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

# Attach the fine-tuned LoRA security adapter on top of the base weights.
print("Loading fine-tuned adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
model.eval()  # inference only
print("Model loaded successfully!")
# Sample contracts for demo
# Canned Solidity snippets for the UI dropdown: three intentionally vulnerable
# examples (reentrancy, integer underflow on pre-0.8 Solidity, missing access
# control) plus one hardened example. Keys are the dropdown labels; values are
# loaded verbatim into the code editor by load_sample().
SAMPLE_CONTRACTS = {
"Reentrancy Attack": '''// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
contract VulnerableBank {
mapping(address => uint256) public balances;
function deposit() public payable {
balances[msg.sender] += msg.value;
}
function withdraw(uint256 amount) public {
require(balances[msg.sender] >= amount);
(bool success, ) = msg.sender.call{value: amount}("");
require(success);
balances[msg.sender] -= amount; // VULNERABLE: State updated after external call
}
}''',
"Integer Overflow": '''// SPDX-License-Identifier: MIT
pragma solidity 0.7.0;
contract VulnerableToken {
mapping(address => uint256) public balances;
function transfer(address to, uint256 amount) public {
balances[msg.sender] -= amount; // VULNERABLE: No underflow check
balances[to] += amount;
}
}''',
"Access Control Missing": '''// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
contract VulnerableAdmin {
address public owner;
constructor() {
owner = msg.sender;
}
function setOwner(address newOwner) public { // VULNERABLE: Anyone can change owner
owner = newOwner;
}
}''',
"Secure Contract": '''// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
contract SecureVault {
mapping(address => uint256) public balances;
bool private locked;
modifier noReentrant() {
require(!locked, "No reentrancy");
locked = true;
_;
locked = false;
}
function withdraw(uint256 amount) public noReentrant {
require(balances[msg.sender] >= amount);
balances[msg.sender] -= amount;
(bool success, ) = msg.sender.call{value: amount}("");
require(success);
}
}'''
}
def analyze_contract(contract_code, show_tokens=False):
    """Run the fine-tuned model over a Solidity contract and format a report.

    Args:
        contract_code: Raw Solidity source pasted or loaded by the user.
        show_tokens: When True, return the raw model output (including the
            custom security tokens) inside a fenced code block.

    Returns:
        A Markdown string with the analysis, or a warning message when the
        input is empty.
    """
    if not contract_code.strip():
        return "⚠️ Please enter a smart contract to analyze."

    # Mistral-instruct prompt format: the contract goes inside [INST] tags.
    input_text = f"""<s>[INST] Analyze this contract and provide structured security report:
{contract_code} [/INST]
"""
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=2048)
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")

    # Generate analysis. model.eval() was already set at load time, so it is
    # not repeated here. pad_token_id is passed explicitly to avoid the
    # per-call "Setting pad_token_id to eos_token_id" warning.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.3,
            do_sample=True,
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Keep special tokens so the custom security tokens survive decoding;
    # everything after the last [/INST] is the model's answer.
    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    analysis = response.split('[/INST]')[-1].strip()

    if show_tokens:
        return f"### 🔍 Security Analysis (with Custom Tokens)\n\n```\n{analysis}\n```"

    # Clean up for display.
    clean_analysis = analysis.replace('</s>', '').strip()
    # The fine-tune emits XML-style severity/vulnerability markers; any of
    # these in the raw output means at least one finding.
    is_vulnerable = any(
        marker in analysis for marker in ('<CRITICAL>', '<HIGH>', '<VULN_')
    )
    if is_vulnerable:
        output = f"""### 🚨 VULNERABILITY DETECTED
{clean_analysis}
---
**Model**: Mistral-7B Fine-tuned with 38 Custom Security Tokens
**Accuracy**: 99.6% | **Precision**: 100% | **Recall**: 99.3%
"""
    else:
        output = f"""### ✅ CONTRACT APPEARS SECURE
{clean_analysis}
---
**Model**: Mistral-7B Fine-tuned with 38 Custom Security Tokens
**Accuracy**: 99.6% | **Precision**: 100% | **Recall**: 99.3%
"""
    return output
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Smart Contract Security Analyzer") as demo:
    # Header / feature blurb.
    gr.Markdown("""
# 🔍 Smart Contract Security Analyzer
### Powered by Fine-Tuned Mistral-7B with Custom Security Tokens
**Hackathon Submission Features:**
- ✨ 38 custom security tokens (NEW capability)
- 📊 99.6% accuracy on 30K balanced dataset
- 🎯 100% precision (zero false positives)
- 📈 +28.6% improvement over base Mistral
- 🏗️ Structured XML-style vulnerability reports
""")

    with gr.Row():
        # Left column: code editor plus action buttons.
        with gr.Column(scale=2):
            contract_input = gr.Code(
                label="📝 Paste Solidity Contract Here",
                language="solidity",
                lines=20,
                placeholder="// SPDX-License-Identifier: MIT\npragma solidity ^0.8.0;\n\ncontract MyContract {\n // Your contract code here\n}"
            )
            with gr.Row():
                analyze_btn = gr.Button("🔍 Analyze Contract", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", size="lg")
            show_tokens_checkbox = gr.Checkbox(
                label="Show custom security tokens in output",
                value=False
            )
        # Right column: sample-contract picker.
        with gr.Column(scale=1):
            gr.Markdown("### 📚 Sample Contracts")
            sample_dropdown = gr.Dropdown(
                choices=list(SAMPLE_CONTRACTS.keys()),
                label="Load Example",
                value=None
            )
            load_sample_btn = gr.Button("Load Sample", size="sm")

    output_display = gr.Markdown(label="Analysis Result")

    # Performance stats table.
    gr.Markdown("""
---
### 📊 Model Performance Metrics
| Metric | Base Mistral | Fine-Tuned | Improvement |
|--------|-------------|------------|-------------|
| **Accuracy** | 71.0% | **99.6%** | +28.6% |
| **Precision** | 64.2% | **100.0%** | +35.8% |
| **Recall** | 100.0% | 99.3% | -0.7% |
| **F1 Score** | 0.782 | **0.996** | +0.214 |
| **Custom Tokens** | 0/38 | **25/38** | NEW! |
| **Structured Output** | No | **Yes** | NEW! |
**Training Details:**
- 30,000 balanced samples (50% vulnerable, 50% safe)
- 6 vulnerability types: Reentrancy, Overflow, Access Control, Unchecked Call, DoS, Timestamp
- 4-bit quantization with LoRA (1.1% trainable params)
- 5.5 hours training on Google Colab G4 GPU
""")

    # --- Event handlers -----------------------------------------------------
    def load_sample(sample_name):
        # Map a dropdown label to its canned contract; empty editor otherwise.
        return SAMPLE_CONTRACTS.get(sample_name, "") if sample_name else ""

    analyze_btn.click(
        fn=analyze_contract,
        inputs=[contract_input, show_tokens_checkbox],
        outputs=output_display
    )
    load_sample_btn.click(
        fn=load_sample,
        inputs=sample_dropdown,
        outputs=contract_input
    )
    clear_btn.click(
        fn=lambda: ("", ""),
        outputs=[contract_input, output_display]
    )
# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()