# AI-Trainer / inference.py
# Committed by V1vex — commit 625b444 ("first-commit")
import json
import os
import re
import sys
import time

import requests
from openai import OpenAI
# Base URL of the OpenEnv server. Defaults to a local server; override it with the
# ENV_BASE_URL environment variable when the server runs elsewhere. A trailing "/"
# is stripped because callers append paths such as "/reset" and "/step".
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:8000").rstrip("/")
# Matches the first complete Markdown code fence (```, ```python, ```py, ...) and
# captures its body.
_CODE_FENCE_RE = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)

# Submitted when the model call fails or returns no usable text.
_FALLBACK_ANSWER = "def generic_answer(): pass"


def _extract_task(resp):
    """Return the task prompt contained in a decoded ``/reset`` response body.

    Looks, in order, at ``observation.echoed_message``, ``observation.task_prompt``
    and top-level ``echoed_message``. Falls back to the whole body serialized as
    JSON, or ``str(resp)`` when the body is not a JSON object.
    """
    if not isinstance(resp, dict):
        return str(resp)
    observation = resp.get("observation")
    if isinstance(observation, dict):
        for key in ("echoed_message", "task_prompt"):
            if key in observation:
                return observation[key]
    if "echoed_message" in resp:
        return resp["echoed_message"]
    return json.dumps(resp)


def _fetch_task():
    """POST ``/reset`` to the environment and return the task prompt.

    Exits the process via ``sys.exit`` when the server is unreachable, answers with
    an HTTP error status, or returns a body that is not valid JSON.
    """
    try:
        response = requests.post(f"{ENV_BASE_URL}/reset", timeout=120)
        # Without this, a JSON error body (e.g. HTTP 500) would be handed to the
        # agent as if it were the task.
        response.raise_for_status()
        body = response.json()
    except (requests.RequestException, ValueError) as e:
        sys.exit(f"🚨 Error: Cannot connect to OpenEnv Server. {e}")
    return _extract_task(body)


def _strip_code_fences(text):
    """Return the code inside the first complete Markdown fence in *text*.

    If *text* contains no complete fence, any stray ```` ```python ```` / ```` ``` ````
    markers are removed instead. The result is whitespace-stripped.
    """
    match = _CODE_FENCE_RE.search(text)
    if match:
        return match.group(1).strip()
    return text.replace("```python", "").replace("```", "").strip()


def _ask_agent(task):
    """Ask the Qwen model on the Hugging Face router to solve *task*.

    Reads the API key from the ``HF_TOKEN`` environment variable and warns if it is
    missing. Returns the reply with any Markdown fence removed, or
    ``_FALLBACK_ANSWER`` when the call fails or the reply is empty.
    """
    print("🤖 Agent is thinking (Using Qwen 2.5)...")
    hf_token = os.environ.get("HF_TOKEN", "")
    if not hf_token:
        print("🚨 WARNING: HF_TOKEN is missing. Please set it in your environment variables.")
    client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=hf_token)
    try:
        completion = client.chat.completions.create(
            model="Qwen/Qwen2.5-72B-Instruct",
            messages=[
                {"role": "system", "content": "You are a Python expert. Output ONLY valid Python code. No explanations, no markdown blocks like ```python."},
                {"role": "user", "content": task},
            ],
        )
        # message.content can be None (e.g. an empty or refused completion).
        raw_answer = completion.choices[0].message.content or ""
    except Exception as e:  # deliberate best-effort: any API failure -> stub answer
        print(f"🚨 HF API Error: {e}")
        return _FALLBACK_ANSWER
    answer = _strip_code_fences(raw_answer)
    if not answer:
        print("🚨 HF API Error: empty completion")
        return _FALLBACK_ANSWER
    return answer


def _extract_score(result):
    """Return the reward from a decoded ``/step`` response body as a float.

    Prefers ``observation.reward`` and falls back to top-level ``reward``. Missing,
    ``null`` or non-numeric values are skipped; if none is usable the result is 0.0,
    so the match total in the caller can always be summed.
    """
    if not isinstance(result, dict):
        return 0.0
    candidates = []
    observation = result.get("observation")
    if isinstance(observation, dict):
        candidates.append(observation.get("reward"))
    candidates.append(result.get("reward"))
    for value in candidates:
        if value is None:
            continue
        try:
            return float(value)
        except (TypeError, ValueError):
            continue
    return 0.0


def _submit_answer(agent_answer):
    """POST *agent_answer* to ``/step`` and return the reward (0.0 on any failure)."""
    print("⚖️ Submitting to Judge...")
    payload = {"action": {"answer": agent_answer}}
    try:
        step_resp = requests.post(f"{ENV_BASE_URL}/step", json=payload, timeout=120)
        if step_resp.status_code != 200:
            print(f"🚨 Server Error! Status: {step_resp.status_code}")
            print(f"🚨 Details: {step_resp.text}")
            return 0.0
        result = step_resp.json()
    except (requests.RequestException, ValueError) as e:
        print(f"🚨 Server Communication Error: {e}")
        return 0.0
    return _extract_score(result)


def play_round(round_number):
    """Play one round: fetch a task, let the agent answer it, and score the answer.

    Args:
        round_number: Round index; used only in the console banners.

    Returns:
        The judge's reward for this round as a float (0.0 when scoring fails).

    Exits the process (``sys.exit``) if the ``/reset`` request fails.
    """
    print(f"\n{'='*50}\n🏁 ROUND {round_number} STARTS!\n{'='*50}")

    task = _fetch_task()
    print(f"🔥 JUDGE ASKS:\n{task}\n")

    agent_answer = _ask_agent(task)
    snippet = agent_answer if len(agent_answer) <= 150 else agent_answer[:150] + "..."
    print(f"🗣️ AGENT'S ANSWER (Snippet):\n{snippet}\n")

    score = _submit_answer(agent_answer)
    print(f"🏆 ROUND {round_number} SCORE : {score} / 1.0")
    return score
def main(rounds=3):
    """Play a match of *rounds* rounds against the environment and print the total.

    Args:
        rounds: Number of rounds to play (default 3). The printed maximum assumes
            each round scores at most 1.0, matching the per-round "/ 1.0" banner.
    """
    # The agent model used in play_round is Qwen 2.5, so the banner names it.
    print("🚀 [START] QWEN AGENT vs OPENAI JUDGE")
    total_score = 0
    for i in range(1, rounds + 1):
        if i > 1:
            time.sleep(2)  # brief pause between rounds; none after the final round
        total_score += play_round(i)
    print(f"\n🎉🎉 MATCH FINISHED! FINAL TOTAL SCORE: {total_score} / {float(rounds)} 🎉🎉")


if __name__ == "__main__":
    main()