Spaces:
Paused
Paused
| import streamlit as st | |
| import torch | |
| from transformers import AutoModelForCausalLM, LlamaTokenizer | |
| from peft import PeftModel | |
| import gc | |
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the tokenizer and model for the vulnerability detector.

    Wrapped in ``st.cache_resource`` so the model is loaded once per server
    process and shared by every rerun and browser session.
    ``st.session_state`` alone is per-session, so each new visitor would
    otherwise trigger a full reload of the checkpoint. The caller already
    shows a spinner, so the cache's own spinner is disabled.

    Returns:
        tuple: ``(model, tokenizer, device)``. ``device`` is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
    """
    # Imported locally: the module-level imports bring in LlamaTokenizer, not
    # AutoTokenizer, so referencing AutoTokenizer unqualified raised NameError.
    from transformers import AutoTokenizer

    model_name = "peterxyz/detect-llama-34b"
    # Slow (non-"fast") tokenizer, as originally requested for this checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    st.info(f"Using device: {device}")

    # Free unreferenced Python objects first, so that empty_cache() can then
    # hand the CUDA blocks they held back to the driver.
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    if device == "cuda":
        from transformers import BitsAndBytesConfig
        import bitsandbytes  # noqa: F401  (fail fast if the package is missing)

        # 4-bit NF4 weights with double quantization; matmuls run in bfloat16.
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model_nf4 = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=nf4_config,
            device_map="auto",
            trust_remote_code=True,
        )
        # NOTE(review): the same repo id is used as both base model and PEFT
        # adapter, and the CPU branch below does not wrap with PeftModel at
        # all — confirm which loading path the checkpoint actually expects.
        model = PeftModel.from_pretrained(model_nf4, model_name)
    else:
        # CPU: no 4-bit quantization (bitsandbytes is only used on CUDA here);
        # weights are loaded at full float32 precision. If the checkpoint is
        # ~34B parameters as its name suggests, this needs well over 100 GB
        # of RAM — TODO confirm the Space's hardware can hold it.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map={"": device},
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
    return model, tokenizer, device
def analyze_contract(contract_code, model, tokenizer, device, max_new_tokens=1024):
    """Ask the model to identify vulnerabilities in *contract_code*.

    Args:
        contract_code: Solidity source text entered by the user.
        model: Causal language model exposing ``generate`` (as returned by
            ``load_model``).
        tokenizer: Tokenizer matching *model*. If it has no ``pad_token``, its
            ``eos_token`` is assigned as the pad token in place.
        device: Device the tokenized inputs are moved to (``"cuda"``/``"cpu"``).
        max_new_tokens: Maximum number of tokens to generate. This bounds only
            the answer, not prompt plus answer.

    Returns:
        str: The newly generated text only; the prompt is not echoed back.
    """
    prompt = f"{contract_code}\n\nidentify vulnerability of this code given above"

    # generate() needs a pad id; fall back to EOS when the tokenizer has none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): with the tokenizer's truncation side (typically "right"),
    # a contract longer than ~2048 tokens loses the trailing instruction —
    # consider truncating the contract text before building the prompt.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,
    ).to(device)

    outputs = model.generate(
        **inputs,
        # The old max_length=1024 included the prompt, so any prompt over 1024
        # tokens (the input allows up to 2048) left no room for an answer.
        max_new_tokens=max_new_tokens,
        # temperature only takes effect in sampling mode.
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # For a causal LM, outputs[0] is prompt tokens followed by the generated
    # continuation; decode only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Browser-tab title, icon, and wide layout for the whole app.
# NOTE(review): "π" looks like an emoji that was mis-encoded somewhere along
# the way — confirm the intended page icon.
PAGE_TITLE = "Smart Contract Vulnerability Detector"
st.set_page_config(page_title=PAGE_TITLE, page_icon="π", layout="wide")

# Page heading followed by a short description of how to use the app.
INTRO_TEXT = """
This app analyzes Solidity smart contracts for potential vulnerabilities using a fine-tuned LLaMA model.
Simply paste your smart contract code below and click 'Analyze'.
"""
st.title("π " + PAGE_TITLE)
st.markdown(INTRO_TEXT)
# Per-session flag recording whether model/tokenizer/device are already stored
# in this session's state.
st.session_state.setdefault("model_loaded", False)

# First run of the session: load and stash the model. On failure, report the
# error and halt this run; the flag stays False, so the next rerun retries.
if not st.session_state.model_loaded:
    try:
        with st.spinner('Loading model... This might take a few minutes...'):
            loaded = load_model()
            (
                st.session_state.model,
                st.session_state.tokenizer,
                st.session_state.device,
            ) = loaded
            st.session_state.model_loaded = True
        st.success('Model loaded successfully!')
    except Exception as load_error:
        st.error(f"Error loading model: {str(load_error)}")
        st.stop()
# Example contract inserted by the "Load Sample Contract" button.
SAMPLE_CONTRACT = """pragma solidity ^0.5.0;
contract ModifierEntrancy {
mapping (address => uint) public tokenBalance;
string constant name = "Nu Token";
Bank bank;
constructor() public{
bank = new Bank();
}
function airDrop() hasNoBalance supportsToken public{
tokenBalance[msg.sender] += 20;
}
modifier supportsToken() {
require(keccak256(abi.encodePacked("Nu Token")) == bank.supportsToken());
_;
}
modifier hasNoBalance {
require(tokenBalance[msg.sender] == 0);
_;
}
}
contract Bank{
function supportsToken() external returns(bytes32) {
return keccak256(abi.encodePacked("Nu Token"));
}
}"""


def _load_sample_contract():
    """Button callback: place SAMPLE_CONTRACT into the contract text area.

    Callbacks run before the script reruns, which is when Streamlit permits
    writing to a keyed widget's session-state entry.
    """
    st.session_state.contract_code = SAMPLE_CONTRACT


# Main input. key="contract_code" binds the text area to
# st.session_state.contract_code so the sample callback can fill it. The
# previous approach wrote that key on an unkeyed widget and called
# st.experimental_rerun() (deprecated in favour of st.rerun), so the sample
# never appeared.
contract_code = st.text_area(
    "Paste your Solidity contract code here:",
    height=300,
    placeholder="pragma solidity ^0.5.0;\n\ncontract YourContract {\n // Your code here\n}",
    key="contract_code",
)

col1, col2 = st.columns([1, 4])
with col1:
    analyze_button = st.button("Analyze Contract", type="primary")
with col2:
    # Clicking triggers the callback and then Streamlit's own rerun.
    load_sample = st.button("Load Sample Contract", on_click=_load_sample_contract)
# Analysis section: run the model only for non-blank input; a whitespace-only
# text area previously passed the truthiness check and was sent to the model.
has_contract_code = bool((contract_code or "").strip())

if analyze_button and has_contract_code:
    try:
        with st.spinner('Analyzing contract...'):
            analysis = analyze_contract(
                contract_code,
                st.session_state.model,
                st.session_state.tokenizer,
                st.session_state.device,
            )
        st.subheader("Analysis Results")
        # Expandable container, open by default, holding the model's answer.
        with st.expander("View Full Analysis", expanded=True):
            st.markdown(analysis)
    except Exception as e:
        st.error(f"An error occurred during analysis: {str(e)}")
        # Show the traceback rather than repeating the message shown above.
        st.markdown("**Debug Information:**")
        st.exception(e)
elif analyze_button:
    st.warning("Please enter some contract code to analyze.")
# Footer: horizontal rule followed by centred credits. The credits are raw
# HTML, which st.markdown renders only with unsafe_allow_html=True.
FOOTER_HTML = """
<div style='text-align: center'>
<p>Built with Streamlit and Hugging Face Transformers</p>
<p>Model: peterxyz/detect-llama-34b</p>
</div>
"""
for footer_part, allow_html in (("---", False), (FOOTER_HTML, True)):
    if allow_html:
        st.markdown(footer_part, unsafe_allow_html=True)
    else:
        st.markdown(footer_part)