smart-contract-audit-env / inference.py
Ismail131's picture
Upload folder using huggingface_hub
ae29d48 verified
import os
import json
import requests
from openai import OpenAI
from dotenv import load_dotenv
# Pull any values defined in a local .env file into os.environ before reading them.
load_dotenv()

# --- LLM endpoint configuration ---------------------------------------------
# The hackathon evaluation harness is expected to provide these variables;
# the defaults below apply only when a variable is unset.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # handed to the client as its API key

# --- Environment server -----------------------------------------------------
# Set ENV_URL to the deployed Hugging Face Space URL; the default targets a
# server running locally on port 8000.
ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")

# OpenAI-compatible client shared by call_llm.
client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
def _empty_audit_result():
    """Return a fresh placeholder audit result for unusable LLM output."""
    return {
        "analysis": "error",
        "vulnerabilities": [],
        "suggested_fixes": []
    }


def _strip_code_fence(text):
    """Strip a surrounding Markdown code fence (e.g. ```json ... ```) from text.

    Guards against a backend at API_BASE_URL that does not honour
    ``response_format`` and fences its JSON reply. Unfenced text is returned
    with only surrounding whitespace removed.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    # Drop the opening fence line, which may carry a language tag such as "json".
    _, _, body = stripped.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def call_llm(contract_code, task_description):
    """Ask the LLM to audit a smart contract.

    Args:
        contract_code: Contract source, embedded verbatim in the user prompt.
        task_description: Task text from the environment observation.

    Returns:
        The dict parsed from the model's JSON reply. The prompt asks for the keys
        "analysis", "vulnerabilities" and "suggested_fixes". If the reply is
        empty/None, is not valid JSON, or is not a JSON object, the function
        prints it and returns a placeholder dict with analysis "error" and
        empty lists instead.

    Raises:
        Any exception raised by ``client.chat.completions.create`` (API or
        network failures). These are not caught here.
    """
    print(f"Auditing contract with {MODEL_NAME}...")
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system",
             "content": "You are a specialized Solidity security auditor. Your goal is to find vulnerabilities in smart contracts."},
            {"role": "user",
             "content": f"Audit this contract:\n\n{contract_code}\n\n"
                        f"Task: {task_description}\n\n"
                        "Return valid JSON ONLY in this format: {\"analysis\": \"...\", "
                        "\"vulnerabilities\": [{\"type\": \"...\", "
                        "\"line\": N, \"severity\": \"...\", "
                        "\"description\": \"...\"}], "
                        "\"suggested_fixes\": [\"...\"]}"
             }
        ],
        temperature=0.1,
        response_format={"type": "json_object"}  # Ensure JSON output
    )
    # The SDK types message.content as optional. json.loads(None) would raise
    # TypeError, which the JSONDecodeError handler below would not catch.
    content = response.choices[0].message.content
    parsed = None
    if content:
        try:
            parsed = json.loads(_strip_code_fence(content))
        except json.JSONDecodeError:
            parsed = None
    # Callers read the result with .get(), so only a JSON object is usable.
    if not isinstance(parsed, dict):
        print(f"Failed to parse LLM response: {content}")
        return _empty_audit_result()
    return parsed
def run_task(task_id, *, timeout=30.0):
    """Run one task against the environment and return its reward.

    Resets the environment for ``task_id`` and asks the LLM to audit the
    returned contract. It then submits the findings to ``/step`` and prints
    the score and feedback.

    Args:
        task_id: Identifier sent in the ``/reset`` request body.
        timeout: Seconds to wait for each environment HTTP request. Without
            a timeout, ``requests`` can block indefinitely.

    Returns:
        The reward from the ``/step`` response. Returns 0.0 when:
        - an environment request fails (connection error, timeout, etc.);
        - a request returns a non-200 status;
        - a response body is unusable;
        - the response reports a null reward.
    """
    print(f"\n--- Starting Task: {task_id} ---")

    # 1. Reset environment for this task
    try:
        r = requests.post(f"{ENV_URL}/reset",
                          json={"task_id": task_id},
                          timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error resetting environment: {exc}")
        return 0.0
    if r.status_code != 200:
        print(f"Error resetting environment: {r.text}")
        return 0.0
    try:
        obs = r.json()
        contract_code = obs["contract_code"]
        task_description = obs["task_description"]
    except (ValueError, KeyError, TypeError) as exc:
        # ValueError covers a non-JSON body; KeyError/TypeError cover an
        # unexpected payload shape.
        print(f"Error resetting environment: unexpected response ({exc!r}): {r.text}")
        return 0.0

    # 2. Call LLM to audit the contract
    result = call_llm(contract_code, task_description)

    # 3. Step the environment with the agent's findings
    action = {
        "analysis": result.get("analysis", ""),
        "vulnerabilities": result.get("vulnerabilities", []),
        "suggested_fixes": result.get("suggested_fixes", [])
    }
    try:
        r = requests.post(f"{ENV_URL}/step", json=action, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error stepping environment: {exc}")
        return 0.0
    if r.status_code != 200:
        print(f"Error stepping environment: {r.text}")
        return 0.0
    try:
        step_obs = r.json()
    except ValueError:
        print(f"Error stepping environment: invalid JSON response: {r.text}")
        return 0.0

    # A JSON null reward would break the :.2f format below and the caller's sum.
    reward = step_obs.get("reward")
    if reward is None:
        reward = 0.0
    feedback = step_obs.get("feedback", "No feedback")
    print(f"Task: {task_id} | Score: {reward:.2f} | Feedback: {feedback}")
    return reward
if __name__ == "__main__":
    # Probe the environment server once so a missing server is reported up front.
    # The timeout stops the probe from hanging. Any request failure (not only
    # ConnectionError, e.g. a Timeout) just warns, and the tasks still run.
    try:
        requests.get(ENV_URL, timeout=10)
    except requests.RequestException:
        print(f"Warning: Could not connect to {ENV_URL}. Make sure your environment server is running.")

    tasks = ["easy_reentrancy", "medium_multi_vuln", "hard_full_audit"]
    # Each task contributes a reward; the maximum is presumably 1.0 per task,
    # as the "/ N.00" summary suggests.
    total_score = sum(run_task(task_id) for task_id in tasks)
    print(f"\nFinal Total Score: {total_score:.2f} / {len(tasks):.2f}")