# AuditAgent: app/models/analyzer.py
# Last commit by Parsa2025AI: "loading fine-tuned llm" (e5348e1, verified)
import os
import logging
from typing import Optional, Dict, Any
from app.models.patterns import analyze_with_patterns
logger = logging.getLogger(__name__)

# Hub repo id of the fine-tuned model. When empty, the loader reports
# failure immediately and only pattern-based analysis is used.
HF_MODEL_ID: str = os.getenv("HF_MODEL_ID", "")
# Optional Hub access token; when empty, no token is passed to from_pretrained.
HF_TOKEN: str = os.getenv("HF_TOKEN", "")

# Lazily-populated singletons; both stay None until a model load succeeds.
_model: Optional[Any] = None
_tokenizer: Optional[Any] = None
def _load_model() -> bool:
    """Lazily load the fine-tuned causal LM and its tokenizer into module globals.

    The model named by ``HF_MODEL_ID`` is loaded with 4-bit NF4 quantization
    (bitsandbytes) and ``device_map="auto"``. Once loaded, later calls return
    immediately.

    Returns:
        True if the model is loaded (now or on an earlier call). False if no
        model is configured or loading failed. Failures are logged with a
        traceback, not raised, so callers can fall back to pattern-only
        analysis. A failed load is retried on the next call.
    """
    global _model, _tokenizer
    if _model is not None:
        return True
    if not HF_MODEL_ID:
        return False
    try:
        # Imported lazily so pattern-only analysis works even when
        # torch/transformers are not installed.
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

        logger.info("Loading model: %s", HF_MODEL_ID)
        # NOTE(review): bitsandbytes 4-bit loading generally needs a CUDA GPU;
        # on CPU-only hosts this presumably raises and the LLM path stays
        # disabled. Confirm against the deployment target.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            HF_MODEL_ID,
            token=HF_TOKEN or None,
        )
        # Fall back to EOS only when the checkpoint ships no pad token, so a
        # dedicated pad token from fine-tuning is not clobbered.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            HF_MODEL_ID,
            quantization_config=bnb_config,
            device_map="auto",
            token=HF_TOKEN or None,
            low_cpu_mem_usage=True,
        )
        model.eval()
    except Exception as e:
        # Best-effort load: keep the traceback but do not crash the caller.
        logger.exception("Model load failed: %s", e)
        return False
    # Publish only fully initialised objects. A failed model load therefore
    # never leaves a tokenizer behind without its model.
    _tokenizer = tokenizer
    _model = model
    logger.info("✅ Model loaded successfully")
    return True
def _llm_analyze(solidity_code: str) -> Optional[str]:
    """Ask the fine-tuned LLM for a free-text vulnerability analysis.

    Only the first 1500 characters of *solidity_code* are considered. The
    contract text is then trimmed further so the whole prompt stays at about
    512 tokens, while the chat template (closing code fence and
    ``<|assistant|>`` tag) is always kept intact.

    Returns:
        The generated analysis text. Returns None when no model is available,
        inference fails, or the model produced no text.
    """
    if not _load_model():
        return None
    try:
        import torch

        header = (
            "<|user|>\nAnalyze this Solidity contract for security vulnerabilities:\n"
            "```solidity\n"
        )
        footer = "\n```\n<|assistant|>\n"
        max_prompt_tokens = 512
        code = solidity_code[:1500]

        # Truncating the finished prompt (truncation=True, max_length=512)
        # cuts from the right. For long contracts that drops the closing fence
        # and the <|assistant|> tag, so the model continues the code instead of
        # answering. Trim the contract text to the remaining budget instead.
        # The count is approximate: tokens may merge differently at the
        # header/code/footer boundaries.
        overhead = len(_tokenizer(header + footer)["input_ids"])
        budget = max(0, max_prompt_tokens - overhead)
        code_ids = _tokenizer(code, add_special_tokens=False)["input_ids"]
        if len(code_ids) > budget:
            code = _tokenizer.decode(code_ids[:budget], skip_special_tokens=True)
        prompt = header + code + footer

        inputs = _tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(_model.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = _model.generate(
                **inputs,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                pad_token_id=_tokenizer.eos_token_id,
            )
        # generate() returns prompt + continuation; decode only the new tokens.
        generated = outputs[0][inputs["input_ids"].shape[1]:]
        text = _tokenizer.decode(generated, skip_special_tokens=True).strip()
        # Report an empty generation as "no analysis", matching Optional[str].
        return text or None
    except Exception as e:
        logger.exception("Inference error: %s", e)
        return None
def analyze_contract(solidity_code: str) -> Dict[str, Any]:
    """Run pattern-based and LLM analysis over *solidity_code* and merge them.

    Pattern analysis (``analyze_with_patterns``) runs first. Its result is
    copied into a new dict and extended with the LLM output. The two keys
    added here override any same-named keys from the pattern result.

    Returns:
        A new dict containing every key of the pattern result, plus:
        ``llm_analysis``, the LLM text (falsy when unavailable); and
        ``analysis_type``, which is ``"llm+pattern"`` when LLM text was
        produced and ``"pattern"`` otherwise.
    """
    report: Dict[str, Any] = {**analyze_with_patterns(solidity_code)}
    llm_output = _llm_analyze(solidity_code)
    report["llm_analysis"] = llm_output
    if llm_output:
        report["analysis_type"] = "llm+pattern"
    else:
        report["analysis_type"] = "pattern"
    return report
def model_status() -> Dict[str, Any]:
    """Report whether the LLM is in memory and which model id is configured.

    Returns:
        A dict with two keys. ``model_loaded`` is True once the model global
        has been populated. ``model_id`` is ``HF_MODEL_ID``, or
        ``"not configured"`` when it is empty.
    """
    loaded = _model is not None
    configured_id = HF_MODEL_ID if HF_MODEL_ID else "not configured"
    status: Dict[str, Any] = {}
    status["model_loaded"] = loaded
    status["model_id"] = configured_id
    return status