# prompt-optimizer-lora / phase3_test_model.py
# Uploaded by purvbhor-10 ("Upload 4 files", commit 47338bf, verified)
# ============================================================
# PHASE 3 — Testing the trained model
# Run: python phase3_test_model.py
# Make sure the lora-adapter/ folder is in the directory you run this from
# ============================================================
import json
import os
import sys

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# Phase banner: a rule, the title, and a closing rule, one per line.
print("\n".join(["=" * 60, " PHASE 3: Testing Prompt Optimizer", "=" * 60]))
# Model id passed to from_pretrained() for the base model that the LoRA
# adapter is loaded onto, and the (working-directory relative) adapter folder.
BASE_MODEL = "google/gemma-2b"
ADAPTER_PATH = "./lora-adapter"
# ── Check adapter exists ────────────────────────────────────
# The adapter is a directory produced by training. Stop early with a hint
# instead of failing later inside the model-loading code.
if not os.path.isdir(ADAPTER_PATH):
    print("\n❌ lora-adapter/ folder not found!", file=sys.stderr)
    print(" → Download it from Colab after training finishes.", file=sys.stderr)
    # sys.exit rather than the `exit` helper injected by the `site` module,
    # which is missing under `python -S` and in embedded interpreters.
    sys.exit(1)
# ── Load model ──────────────────────────────────────────────
print("\n[1/3] Loading model + LoRA adapter (may take 1–2 min)...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# float16 halves GPU memory, but many CPU kernels lack half-precision support
# (or are very slow with it). Use float32 when no CUDA device is present.
_load_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=_load_dtype,
    device_map="auto",
)
# Wrap the base model with the trained LoRA weights from ADAPTER_PATH.
model = PeftModel.from_pretrained(model, ADAPTER_PATH)
# Inference mode: disables dropout layers.
model.eval()
print(" ✅ Model loaded")
# ── Inference function ──────────────────────────────────────
def improve_prompt(weak_prompt: str, max_new_tokens: int = 250) -> str:
    """Rewrite *weak_prompt* with the fine-tuned model and return the result.

    The prompt is wrapped in the "### Weak Prompt: / ### Improved Prompt:"
    template and sampled from the module-level ``model`` and ``tokenizer``
    (temperature 0.7, top-p 0.9, repetition penalty 1.1).

    Args:
        weak_prompt: The short or vague prompt to improve.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Only the model's generated answer, stripped of surrounding
        whitespace. The echoed input is never included.
    """
    input_text = f"### Weak Prompt:\n{weak_prompt}\n\n### Improved Prompt:\n"
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
        )
    # For a causal LM, generate() returns prompt + continuation. Slice off the
    # prompt tokens instead of splitting decoded text on the header. Splitting
    # on the *last* header returned the answer to a hallucinated follow-up
    # example, and returned the prompt echo when the header was missing.
    prompt_len = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    # If the model ran on into a new template example, keep only the first answer.
    completion = completion.split("### Weak Prompt:", 1)[0]
    return completion.strip()
# ── 15 test prompts (diverse topics) ───────────────────────
# Short, vague "weak" prompts fed one by one to improve_prompt(). List order
# fixes the [i/N] numbering in the console output and the order of the results.
test_prompts = [
    "write about dogs", "explain machine learning", "help me code",
    "tell me about space", "make a diet plan", "write an email",
    "summarize history", "explain climate change", "how to learn python",
    "write a story", "explain blockchain", "give me recipe ideas",
    "help with my resume", "explain quantum computing", "plan a road trip",
]
print(f"\n[2/3] Testing on {len(test_prompts)} prompts...")
print("-" * 60)
# Run every prompt through the model. Echo each weak/improved pair to the
# console and collect it (keys "weak", "improved") for the JSON report.
results = []
for idx, weak in enumerate(test_prompts, start=1):
    print(f"\n[{idx}/{len(test_prompts)}] Weak: {weak}")
    improved = improve_prompt(weak)
    # Answer and separator rule, each on its own line.
    print(f" Improved:\n{improved}", "-" * 60, sep="\n")
    results.append(dict(weak=weak, improved=improved))
# ── Save results ────────────────────────────────────────────
# Persist every weak/improved pair to data/test_results.json. The closing
# message points to dashboard.html for visualizing this file.
os.makedirs("data", exist_ok=True)
with open("data/test_results.json", "w", encoding="utf-8") as f:
    # ensure_ascii=False keeps non-ASCII model output readable. The explicit
    # utf-8 encoding keeps that write portable across platforms.
    json.dump(results, f, indent=2, ensure_ascii=False)
print("\n[3/3] Results saved")
print("\n" + "=" * 60)
print(" ✅ PHASE 3 COMPLETE!")
print(" 📄 Results → data/test_results.json")
print(" ➑️ Next: Open dashboard.html to visualize results")
print("=" * 60)