import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os
# ---------------------------------------------------------------------------
# Model configuration and tokenizer setup.
# ---------------------------------------------------------------------------
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
ADAPTER_MODEL = "mistral-hackaton-2026/Mistral_SmartContract_Security"

print("Loading model... This may take a few minutes on first run.")

# The LoRA adapter shares the base checkpoint's vocabulary, so the tokenizer
# comes straight from the base model.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if not tokenizer.pad_token:
    # Mistral ships without a pad token; reuse EOS so batching/truncation works.
    tokenizer.pad_token = tokenizer.eos_token

# Resolve the compute device once so later code can branch and report on it.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# ---------------------------------------------------------------------------
# Base-model loading. Preference order:
#   1. GPU with 4-bit NF4 quantization (lowest memory),
#   2. GPU in plain float16 if bitsandbytes is unavailable/unsupported,
#   3. CPU in float32 (slow, but requires no CUDA).
# ---------------------------------------------------------------------------
if torch.cuda.is_available():
    try:
        print("Loading with 4-bit quantization...")
        from transformers import BitsAndBytesConfig

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
    except Exception as e:
        # bitsandbytes can be missing or incompatible with the GPU — fall
        # back to half precision rather than aborting the app.
        print(f"Quantization failed ({e}), loading in float16...")
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
else:
    print("Loading on CPU (slower but works without GPU)...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        torch_dtype=torch.float32,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

# Attach the fine-tuned LoRA adapter on top of the frozen base weights.
print("Loading fine-tuned adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
model.eval()  # inference only: disables dropout and other train-time behavior
print("Model loaded successfully!")
# Sample contracts for the demo UI. Keys are the labels shown to the user;
# values are complete Solidity sources fed to the model verbatim. The
# "VULNERABLE" comments inside each snippet mark the intentional flaw.
SAMPLE_CONTRACTS = {
# Reentrancy: the balance is decremented only AFTER the external call, so a
# malicious fallback can re-enter withdraw() before state is updated.
"Reentrancy Attack": '''// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
contract VulnerableBank {
mapping(address => uint256) public balances;
function deposit() public payable {
balances[msg.sender] += msg.value;
}
function withdraw(uint256 amount) public {
require(balances[msg.sender] >= amount);
(bool success, ) = msg.sender.call{value: amount}("");
require(success);
balances[msg.sender] -= amount; // VULNERABLE: State updated after external call
}
}''',
# Underflow: pinned to Solidity 0.7.0, which has no built-in checked
# arithmetic, so subtracting more than the balance wraps around.
"Integer Overflow": '''// SPDX-License-Identifier: MIT
pragma solidity 0.7.0;
contract VulnerableToken {
mapping(address => uint256) public balances;
function transfer(address to, uint256 amount) public {
balances[msg.sender] -= amount; // VULNERABLE: No underflow check
balances[to] += amount;
}
}''',
# Missing access control: setOwner() has no owner-only guard, so any caller
# can take over the contract.
"Access Control Missing": '''// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
contract VulnerableAdmin {
address public owner;
constructor() {
owner = msg.sender;
}
function setOwner(address newOwner) public { // VULNERABLE: Anyone can change owner
owner = newOwner;
}
}''',
# Hardened counter-example: checks-effects-interactions ordering plus an
# explicit reentrancy lock modifier.
"Secure Contract": '''// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
contract SecureVault {
mapping(address => uint256) public balances;
bool private locked;
modifier noReentrant() {
require(!locked, "No reentrancy");
locked = true;
_;
locked = false;
}
function withdraw(uint256 amount) public noReentrant {
require(balances[msg.sender] >= amount);
balances[msg.sender] -= amount;
(bool success, ) = msg.sender.call{value: amount}("");
require(success);
}
}'''
}
def analyze_contract(contract_code, show_tokens=False):
"""Analyze smart contract for vulnerabilities"""
if not contract_code.strip():
return "⚠️ Please enter a smart contract to analyze."
# Prepare input
input_text = f"""[INST] Analyze this contract and provide structured security report:
{contract_code} [/INST]
"""
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=2048)
if torch.cuda.is_available():
inputs = inputs.to("cuda")
# Generate analysis
model.eval()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=300,
temperature=0.3,
do_sample=True,
top_p=0.95
)
response = tokenizer.decode(outputs[0], skip_special_tokens=False)
analysis = response.split('[/INST]')[-1].strip()
# Format output
if show_tokens:
output = f"### 🔍 Security Analysis (with Custom Tokens)\n\n```\n{analysis}\n```"
else:
# Clean up for display
clean_analysis = analysis.replace('', '').strip()
# Check for vulnerabilities
is_vulnerable = any([
'' in analysis,
'' in analysis,
'