import os
import sys
import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import login
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Authenticate with the Hugging Face Hub only when a token is available;
# calling login() with a missing token would raise immediately.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    logger.warning("⚠️ HF_TOKEN not set; skipping Hub login.")
else:
    try:
        login(token=HF_TOKEN, add_to_git_credential=False)
        logger.info("✅ HF Hub login successful")
    except Exception:
        logger.error("❌ HF login failed", exc_info=True)
        sys.exit(1)

BASE_MODEL = "Qwen/Qwen1.5-0.5B-Chat"
LORA_REPO = "MrA7A/pentest-lora-qwen1.5-continual"

logger.info("Loading tokenizer & base model...")
tok = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)
tok.pad_token = tok.eos_token  # Qwen ships no pad token; reuse EOS for padding

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.float32,  # full precision is fine for a 0.5B model on CPU
    device_map="cpu",
    token=HF_TOKEN,  # `use_auth_token` is deprecated in recent transformers
)

logger.info("Applying LoRA adapter...")
model = PeftModel.from_pretrained(base_model, LORA_REPO, token=HF_TOKEN)
model.eval()  # inference only: disables dropout
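
# Optional (a sketch, not part of the original flow): PEFT's merge_and_unload()
# folds the LoRA weights into the base model, removing the adapter indirection
# at generation time:
#   model = model.merge_and_unload()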
def gen_response(prompt: str) -> str:
    # Strip quotes and backslashes so user input cannot break the chat markup.
    safe = prompt.replace('"', '').replace("\\", "")
    text = (
        "<|im_start|>system\nYou are a cybersecurity expert.<|im_end|>\n"
        f"<|im_start|>user\n{safe}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    inputs = tok(text, return_tensors="pt")
    with torch.no_grad():
        out = model.generate(
            **inputs, max_new_tokens=512, do_sample=True, temperature=0.7,
            top_p=0.9, eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id,
        )
    dec = tok.decode(out[0], skip_special_tokens=False)
    # Return only the assistant turn; fall back to the full decode if the
    # template markers are missing.
    if "<|im_start|>assistant" in dec:
        return dec.split("<|im_start|>assistant\n", 1)[1].split("<|im_end|>")[0].strip()
    return dec
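
# An alternative sketch for building the prompt (assumes a transformers
# version with chat-template support, >= 4.34): Qwen1.5 chat tokenizers ship
# a chat template, so the manual <|im_start|> markup above could be replaced
# with:
#   messages = [
#       {"role": "system", "content": "You are a cybersecurity expert."},
#       {"role": "user", "content": prompt},
#   ]
#   text = tok.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True
#   )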
app = FastAPI()


class Prompt(BaseModel):
    prompt: str


@app.get("/ping")
def ping():
    return {"status": "ok"}


@app.post("/generate")
def generate(body: Prompt):
    if not body.prompt:
        raise HTTPException(status_code=400, detail="Empty prompt")
    resp = gen_response(body.prompt)
    return {"response": resp}
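

# A minimal sketch of a local entry point, assuming uvicorn is installed
# (pip install uvicorn); the host and port below are illustrative.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is up (illustrative prompt):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain the phases of a penetration test."}'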