# CSTest / app.py — FastAPI inference Space (Qwen1.5-0.5B-Chat + pentest LoRA)
# Source: Hugging Face Space by MrA7A, commit 1414d80 ("Update app.py", verified)
import os, sys, json, torch, logging
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import login
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# --- Logging ----------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Hugging Face Hub authentication ----------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # No token: public repos can still be downloaded anonymously, so warn and
    # continue instead of attempting login(token=None), which would raise and
    # kill the process despite the warning's "continue anyway" intent.
    logger.warning("⚠️ HF_TOKEN not set.")
else:
    try:
        login(token=HF_TOKEN, add_to_git_credential=False)
        logger.info("✅ HF Hub login successful")
    except Exception:
        # A bad token is fatal: the LoRA repo below may be private.
        logger.error("❌ HF login failed", exc_info=True)
        sys.exit(1)

BASE_MODEL = "Qwen/Qwen1.5-0.5B-Chat"
LORA_REPO = "MrA7A/pentest-lora-qwen1.5-continual"

logger.info("Loading tokenizer & base model...")
# `token=` replaces the deprecated `use_auth_token=` kwarg; None is accepted
# and means anonymous access.
tok = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)
# Qwen chat models ship without a pad token; reuse EOS so generate() can pad.
tok.pad_token = tok.eos_token
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.float32,  # CPU-only host: fp32 avoids unsupported half-precision ops
    device_map="cpu",
    token=HF_TOKEN,
)
logger.info("Applying LoRA adapter...")
model = PeftModel.from_pretrained(base_model, LORA_REPO, token=HF_TOKEN)
model.eval()  # inference mode: disable dropout
def gen_response(prompt: str) -> str:
    """Generate a model reply for *prompt* using the Qwen ChatML template.

    The prompt is lightly sanitized (double quotes and backslashes removed)
    before being wrapped in <|im_start|>/<|im_end|> chat markers. Returns only
    the assistant portion of the decoded output when the marker is found,
    otherwise the full decoded text.
    """
    # NOTE(review): stripping these characters alters user input; presumably
    # meant to keep the ChatML template intact — confirm it is intentional.
    safe = prompt.replace('"', '').replace("\\", "")
    text = f"""<|im_start|>system\nYou are a cybersecurity expert.<|im_end|>\n<|im_start|>user\n{safe}<|im_end|>\n<|im_start|>assistant\n"""
    inputs = tok(text, return_tensors="pt")
    with torch.no_grad():
        out = model.generate(
            **inputs, max_new_tokens=512, do_sample=True, temperature=0.7,
            top_p=0.9, eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id
        )
    dec = tok.decode(out[0], skip_special_tokens=False)
    # Bug fix: split on the SAME marker that was tested. The original checked
    # for "<|im_start|>assistant" but split on "...assistant\n", raising
    # IndexError whenever the marker appeared without a trailing newline.
    marker = "<|im_start|>assistant"
    if marker in dec:
        # strip() removes the leading newline the old split consumed, so the
        # returned text is unchanged for well-formed output.
        return dec.split(marker, 1)[1].split("<|im_end|>")[0].strip()
    return dec
# FastAPI application exposing /ping and /generate.
app = FastAPI()
class Prompt(BaseModel):
    """Request body for POST /generate."""
    # Raw user prompt; rejected with HTTP 400 by the endpoint when empty.
    prompt: str
@app.get("/ping")
def ping():
    """Liveness probe: always reports the service as healthy."""
    return dict(status="ok")
@app.post("/generate")
def generate(body: Prompt):
    """Run the model on the submitted prompt and return its reply.

    Raises:
        HTTPException 400: prompt is empty or whitespace-only (the original
            check let whitespace-only prompts reach the model).
        HTTPException 500: generation failed — previously any model error
            escaped as an unhandled server traceback.
    """
    if not body.prompt or not body.prompt.strip():
        raise HTTPException(status_code=400, detail="Empty prompt")
    try:
        resp = gen_response(body.prompt)
    except Exception:
        logger.exception("Generation failed")
        raise HTTPException(status_code=500, detail="Generation failed")
    return {"response": resp}