File size: 2,490 Bytes
bc20ef9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import torch
import httpx
from transformers import AutoTokenizer, AutoModelForCausalLM

# Base URL of the locally running environment server (uvicorn, port 7860).
ENV_URL = "http://localhost:7860"
# Hugging Face model id; small instruct model chosen so the test runs on CPU.
MODEL_NAME = "Qwen/Qwen2.5-Coder-0.5B-Instruct"

def test_logic():
    """End-to-end smoke test of the environment server + model loop.

    Steps: (1) ping the server's /health endpoint, (2) load the tokenizer
    and model on CPU, (3) reset the environment to a fixed task, (4) ask
    the model for a SQL fix, (5) submit the fix via /step and print the
    reward. Prints progress throughout; returns None in all cases.

    Requires the environment server to be running locally and downloads
    MODEL_NAME from the Hugging Face hub on first run.
    """
    print("πŸš€ Starting Logic Smoke Test...")

    # 1. Check if server is up. Catch only httpx's transport/HTTP errors so
    # unrelated bugs (and KeyboardInterrupt) are not silently swallowed by a
    # bare `except:`. A timeout keeps the script from hanging on a dead port.
    try:
        httpx.get(f"{ENV_URL}/health", timeout=5.0)
        print("βœ… Environment server is alive.")
    except httpx.HTTPError:
        print("❌ Error: Server not found. Run 'python3 -m uvicorn server.main:app --port 7860' first.")
        return

    # 2. Load model (CPU only to save disk/temp space)
    print(f"πŸ“¦ Loading model {MODEL_NAME} on CPU...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to("cpu")

    # 3. Get a task. Check the status code before indexing into the JSON,
    # mirroring the error handling already done for /step below.
    resp = httpx.post(
        f"{ENV_URL}/reset",
        json={"task_id": "easy_syntax_fix"},
        timeout=30.0,
    )
    if resp.status_code != 200:
        print(f"❌ Server Error {resp.status_code} on /reset: {resp.text}")
        return
    obs = resp.json()["observation"]
    print(f"πŸ“ Task Loaded: {obs['task_description'][:100]}...")

    # 4. Ask Model for a fix
    prompt = f"Fix this SQL query:\n{obs['original_query']}\nProvide ONLY the fixed SQL query, no other text."
    inputs = tokenizer(prompt, return_tensors="pt")

    print("πŸ€– AI is thinking...")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=100,
        pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the NEW tokens (the prompt is echoed back at the front
    # of `outputs`, so slice it off before decoding).
    fix = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()

    if not fix:
        fix = "SELECT * FROM users;" # Fallback for test if AI is silent
        print("⚠️ AI was silent, using fallback query for connection test.")
    else:
        print(f"✨ AI Proposed Fix: {fix}")

    # 5. Get Reward. Generous timeout: grading may run the query server-side.
    print("🎯 Sending to environment for grading...")
    step_resp = httpx.post(
        f"{ENV_URL}/step",
        json={"action": {"action_type": "submit_query", "query": fix}},
        timeout=60.0,
    )

    if step_resp.status_code != 200:
        print(f"❌ Server Error {step_resp.status_code}: {step_resp.text}")
        return

    result = step_resp.json()

    print("πŸ† TEST RESULT:")
    print(f"   - Reward Score: {result.get('reward', 'MISSING')}")
    print(f"   - Done: {result.get('done', 'MISSING')}")

    # Compare against None explicitly: a legitimate reward of 0.0 is falsy,
    # so a bare truthiness check would wrongly treat it as missing.
    reward = result.get('reward')
    if reward is not None and reward >= 0.5:
        print("   - Status: Success! System is fully operational.")
    else:
        print("   - Status: Connection test passed (Reward received).")

# Script entry point: run the smoke test only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    test_logic()