"""Solidity contract analysis combining pattern matching with an optional LLM.

The LLM is configured through the ``HF_MODEL_ID`` / ``HF_TOKEN`` environment
variables and loaded lazily on first use. When no model is configured, or it
cannot be loaded, results come from pattern matching alone.
"""

import os
import logging
import time
from typing import Any, Dict, List, Optional

from app.models.patterns import analyze_with_patterns

logger = logging.getLogger(__name__)

HF_MODEL_ID = os.getenv("HF_MODEL_ID", "")
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Only this many leading characters of the contract are sent to the LLM.
_MAX_CODE_CHARS = 1500
# Upper bound on the prompt length in tokens (instruction + code + markers).
_MAX_PROMPT_TOKENS = 512
# After a failed load, wait this long before trying again rather than
# re-attempting the expensive load on every request.
_LOAD_RETRY_SECONDS = 300.0

_PROMPT_PREFIX = (
    "<|user|>\nAnalyze this Solidity contract for security vulnerabilities:\n"
    "```solidity\n"
)
_PROMPT_SUFFIX = "\n```\n<|assistant|>\n"

_model = None
_tokenizer = None
# time.monotonic() of the most recent failed load attempt; None if the last
# attempt did not fail (or none has been made).
_load_failed_at: Optional[float] = None


def _load_model() -> bool:
    """Load the tokenizer and 4-bit quantized model once; report readiness.

    Returns:
        True if the model is loaded (now or previously). False if no model is
        configured, loading failed, or a previous failure happened less than
        ``_LOAD_RETRY_SECONDS`` ago.

    NOTE(review): there is no lock here, so concurrent first requests may
    each start a load.
    """
    global _model, _tokenizer, _load_failed_at
    if _model is not None:
        return True
    if not HF_MODEL_ID:
        return False
    if (
        _load_failed_at is not None
        and time.monotonic() - _load_failed_at < _LOAD_RETRY_SECONDS
    ):
        # Fail fast instead of paying the full import/download cost again.
        return False

    try:
        import torch
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            BitsAndBytesConfig,
        )

        logger.info("Loading model: %s", HF_MODEL_ID)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            HF_MODEL_ID,
            token=HF_TOKEN or None,
        )
        # Many causal LMs ship without a pad token; fall back to EOS only
        # when none is defined, instead of overriding a real one.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            HF_MODEL_ID,
            quantization_config=bnb_config,
            device_map="auto",
            token=HF_TOKEN or None,
            low_cpu_mem_usage=True,
        )
        model.eval()
    except Exception as e:
        _load_failed_at = time.monotonic()
        logger.exception("Model load failed: %s", e)
        return False

    # Publish both globals only after everything succeeded, so a failure
    # never leaves a tokenizer without its model behind.
    _tokenizer, _model = tokenizer, model
    _load_failed_at = None
    logger.info("✅ Model loaded successfully")
    return True


def _build_prompt_ids(solidity_code: str) -> List[int]:
    """Return prompt token ids, shortening only the contract code to fit.

    Truncating the assembled prompt from the right would drop its tail (the
    closing code fence and the ``<|assistant|>`` marker). Instead, the
    instruction prefix and suffix are always kept whole, and the code gets
    whatever remains of ``_MAX_PROMPT_TOKENS``.
    """
    # The prefix gets the tokenizer's default special tokens (e.g. BOS, if
    # the tokenizer adds one); the code and suffix are appended without any.
    prefix_ids = _tokenizer(_PROMPT_PREFIX)["input_ids"]
    suffix_ids = _tokenizer(_PROMPT_SUFFIX, add_special_tokens=False)["input_ids"]
    code_ids = _tokenizer(
        solidity_code[:_MAX_CODE_CHARS], add_special_tokens=False
    )["input_ids"]
    budget = max(0, _MAX_PROMPT_TOKENS - len(prefix_ids) - len(suffix_ids))
    return list(prefix_ids) + list(code_ids[:budget]) + list(suffix_ids)


def _llm_analyze(solidity_code: str) -> Optional[str]:
    """Ask the LLM for a free-text security review of *solidity_code*.

    Returns:
        The generated text, or None if the model is unavailable, inference
        raised, or the model produced no text.
    """
    if not _load_model():
        return None
    try:
        import torch

        input_ids = torch.tensor(
            [_build_prompt_ids(solidity_code)], device=_model.device
        )
        attention_mask = torch.ones_like(input_ids)
        with torch.no_grad():
            outputs = _model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=300,
                temperature=0.7,
                do_sample=True,
                pad_token_id=_tokenizer.eos_token_id,
            )
        # generate() returns prompt + continuation; keep only the new tokens.
        generated = outputs[0][input_ids.shape[1]:]
        text = _tokenizer.decode(generated, skip_special_tokens=True).strip()
    except Exception as e:
        logger.exception("Inference error: %s", e)
        return None
    # Report an empty completion as "no LLM analysis" so analyze_contract's
    # analysis_type stays consistent with llm_analysis.
    return text or None


def analyze_contract(solidity_code: str) -> Dict[str, Any]:
    """Analyze a Solidity contract with patterns and, if available, the LLM.

    Returns:
        The mapping returned by ``analyze_with_patterns`` (assumed to be a
        dict, since it is unpacked with ``**``), extended with
        ``llm_analysis`` (str or None) and ``analysis_type``
        (``"llm+pattern"`` when LLM text is present, else ``"pattern"``).
    """
    pattern_result = analyze_with_patterns(solidity_code)
    llm_text = _llm_analyze(solidity_code)
    return {
        **pattern_result,
        "llm_analysis": llm_text,
        "analysis_type": "llm+pattern" if llm_text else "pattern",
    }


def model_status() -> Dict[str, Any]:
    """Report whether the LLM is loaded and which model id is configured."""
    return {
        "model_loaded": _model is not None,
        "model_id": HF_MODEL_ID or "not configured",
    }