import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer from peft import PeftModel import os # Model configuration BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3" ADAPTER_MODEL = "mistral-hackaton-2026/Mistral_SmartContract_Security" print("Loading model... This may take a few minutes on first run.") # Load tokenizer from base model tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token # Check if CUDA is available device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Load base model (with quantization only if GPU available) if torch.cuda.is_available(): try: print("Loading with 4-bit quantization...") from transformers import BitsAndBytesConfig bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=False, ) base_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, quantization_config=bnb_config, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True ) except Exception as e: print(f"Quantization failed ({e}), loading in float16...") base_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True ) else: print("Loading on CPU (slower but works without GPU)...") base_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, device_map="auto", torch_dtype=torch.float32, trust_remote_code=True, low_cpu_mem_usage=True ) # Load PEFT adapter print("Loading fine-tuned adapter...") model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL) model.eval() # Set to evaluation mode print("Model loaded successfully!") # Sample contracts for demo SAMPLE_CONTRACTS = { "Reentrancy Attack": '''// SPDX-License-Identifier: MIT pragma solidity ^0.8.0; contract VulnerableBank { mapping(address => uint256) public balances; function deposit() public payable { balances[msg.sender] += msg.value; } function withdraw(uint256 amount) public { require(balances[msg.sender] >= amount); (bool success, ) = msg.sender.call{value: amount}(""); require(success); balances[msg.sender] -= amount; // VULNERABLE: State updated after external call } }''', "Integer Overflow": '''// SPDX-License-Identifier: MIT pragma solidity 0.7.0; contract VulnerableToken { mapping(address => uint256) public balances; function transfer(address to, uint256 amount) public { balances[msg.sender] -= amount; // VULNERABLE: No underflow check balances[to] += amount; } }''', "Access Control Missing": '''// SPDX-License-Identifier: MIT pragma solidity ^0.8.0; contract VulnerableAdmin { address public owner; constructor() { owner = msg.sender; } function setOwner(address newOwner) public { // VULNERABLE: Anyone can change owner owner = newOwner; } }''', "Secure Contract": '''// SPDX-License-Identifier: MIT pragma solidity ^0.8.0; contract SecureVault { mapping(address => uint256) public balances; bool private locked; modifier noReentrant() { require(!locked, "No reentrancy"); locked = true; _; locked = false; } function withdraw(uint256 amount) public noReentrant { require(balances[msg.sender] >= amount); balances[msg.sender] -= amount; (bool success, ) = msg.sender.call{value: amount}(""); require(success); } }''' } def analyze_contract(contract_code, show_tokens=False): """Analyze smart contract for vulnerabilities""" if not contract_code.strip(): return "⚠️ Please enter a smart contract to analyze." # Prepare input input_text = f"""[INST] Analyze this contract and provide structured security report: {contract_code} [/INST] """ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=2048) if torch.cuda.is_available(): inputs = inputs.to("cuda") # Generate analysis model.eval() with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=300, temperature=0.3, do_sample=True, top_p=0.95 ) response = tokenizer.decode(outputs[0], skip_special_tokens=False) analysis = response.split('[/INST]')[-1].strip() # Format output if show_tokens: output = f"### 🔍 Security Analysis (with Custom Tokens)\n\n```\n{analysis}\n```" else: # Clean up for display clean_analysis = analysis.replace('', '').strip() # Check for vulnerabilities is_vulnerable = any([ '' in analysis, '' in analysis, '