import asyncio
import json
import os
import sys
import textwrap
from typing import List, Optional

from dotenv import load_dotenv
from openai import OpenAI

from client import ContractValidationEnv
from models import ContractValidationAction
|
|
|
|
|
# Load variables from a local .env file (if any) into the process environment
# before the settings below are read.
load_dotenv()

# Connection settings for the OpenAI-compatible client. Each value can be
# overridden through the environment; HF_TOKEN has no default and is None
# when unset.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")

BENCHMARK = "contract_validation"

# Upper bound on the observation's step_count before the agent loop stops.
MAX_STEPS = 15
|
|
|
|
|
|
|
|
|
def log_start(task: str) -> None:
    """Emit the ``[START]`` marker line for *task* on stdout, flushed immediately."""
    line = "[START] task={}".format(task)
    print(line, flush=True)
|
|
|
|
|
def log_step(step: int, reward: float) -> None:
    """Emit a ``[STEP]`` line with the reward rounded to 2 decimals and floored at 0.

    Args:
        step: Step counter to report.
        reward: Raw reward for the step; negative values are reported as 0.0.
    """
    # Use a float floor: with an int ``0`` as the first argument, ``max`` returns
    # that int for zero/negative rewards, so the log would mix ``reward=0`` with
    # ``reward=0.5``. A float floor keeps the field consistently float-formatted
    # (and also normalises ``-0.0`` to ``0.0``).
    clamped_reward = max(0.0, round(reward, 2))
    print(f"[STEP] step={step} reward={clamped_reward}", flush=True)
|
|
|
|
|
def log_end(task: str, score: float, steps: int) -> None:
    """Emit the ``[END]`` line for *task*.

    The score is rounded to 2 decimals and then clamped into [0.01, 0.99]
    before being printed together with the number of steps taken.
    """
    rounded = round(score, 2)
    # Upper bound first, then lower bound — same order as min-then-max.
    capped = rounded if rounded < 0.99 else 0.99
    final_score = capped if capped > 0.01 else 0.01
    print(
        "[END] task={} score={} steps={}".format(task, final_score, steps),
        flush=True,
    )
|
|
|
|
|
def _coerce_bool(value) -> bool:
    """Interpret a JSON field as a boolean without treating ``"false"`` as true."""
    if isinstance(value, str):
        # bool("false") is True; models sometimes emit booleans as strings.
        return value.strip().lower() in {"true", "1", "yes"}
    return bool(value)


async def run_task(
    client: OpenAI,
    task_level: str,
    space_url: str = "https://envarchitects-contract-validation-env.hf.space",
) -> None:
    """Run one episode of the contract-validation task at *task_level*.

    Resets the environment, then repeatedly asks the chat model for a JSON
    action and forwards it to the environment until the environment reports
    ``done`` or the observation's ``step_count`` reaches ``MAX_STEPS``.
    Emits ``[START]``, one ``[STEP]`` per step, and ``[END]`` on stdout; the
    final score is read from ``obs.info["score"]`` (0.0 when absent).

    Args:
        client: Synchronous OpenAI-compatible client used for chat completions.
        task_level: Difficulty label passed to ``env.reset``.
        space_url: Base URL of the environment server.

    Raises:
        Whatever ``env.reset`` / ``env.step`` raise; model and parsing errors
        are reported on stderr and replaced by a no-op action instead.
    """
    # The system prompt never changes, so build it once rather than per step.
    system_prompt = textwrap.dedent("""
        You are a precise legal AI. Review the clauses and output valid JSON.
        Your JSON must match exactly:
        {
            "thoughts": "your reasoning",
            "clause_id": 1,
            "risk_type": "liability",
            "submit_final": false
        }
        Valid risk types: liability, payment, termination, confidentiality, compliance, none.
    """).strip()

    env = ContractValidationEnv(base_url=space_url)

    try:
        result = await env.reset(task_level=task_level)
        obs = result.observation
        done = False

        log_start(task=task_level)

        while not done and obs.step_count < MAX_STEPS:
            user_prompt = textwrap.dedent(f"""
                Current Clauses: {json.dumps(obs.contract_clauses)}
                Risks You Have ALREADY Flagged: {json.dumps(obs.flagged_risks)}

                Instructions:
                1. Identify any unflagged risks in the Current Clauses.
                2. If there is a risk you haven't flagged yet, output its clause_id and risk_type.
                3. DO NOT repeat an action. If a clause is already in your "ALREADY Flagged" list, leave it alone.
                4. CRITICAL: If you have found all the risks (or if the remaining clauses are perfectly safe), you MUST end the review by setting "submit_final": true, "clause_id": 0, and "risk_type": "none".
            """).strip()

            try:
                # The client call is synchronous; run it in a worker thread so
                # the event loop is not blocked while waiting for the model.
                response = await asyncio.to_thread(
                    client.chat.completions.create,
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    response_format={"type": "json_object"},
                    temperature=0.1,
                )

                raw_response = response.choices[0].message.content
                parsed = json.loads(raw_response)

                action = ContractValidationAction(
                    clause_id=int(parsed.get("clause_id", 0)),
                    risk_type=str(parsed.get("risk_type", "none")),
                    submit_final=_coerce_bool(parsed.get("submit_final", False)),
                    explanation=parsed.get("thoughts", ""),
                )
            except Exception as exc:
                # Best-effort: a failed or malformed model reply must not abort
                # the episode. Report it on stderr (stdout carries the
                # [START]/[STEP]/[END] lines) and fall back to a no-op action.
                print(
                    f"[WARN] task={task_level} model step failed: {exc!r}",
                    file=sys.stderr,
                    flush=True,
                )
                action = ContractValidationAction(
                    clause_id=0, risk_type="none", submit_final=False)

            result = await env.step(action)
            obs = result.observation

            step_reward = result.reward if result.reward is not None else 0.0
            done = result.done

            log_step(step=obs.step_count, reward=step_reward)

        score = obs.info.get("score", 0.0)

        log_end(task=task_level, score=score, steps=obs.step_count)

    finally:
        try:
            await env.close()
        except Exception as exc:
            # Closing is best-effort; surface the failure without masking any
            # exception already propagating from the episode.
            print(
                f"[WARN] task={task_level} env.close() failed: {exc!r}",
                file=sys.stderr,
                flush=True,
            )
|
|
|
|
|
async def main():
    """Run the benchmark for each difficulty level in sequence.

    When ``HF_TOKEN`` is unset or empty, prints a warning and returns without
    running any task.
    """
    if HF_TOKEN:
        llm_client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
        # Levels run one after another, in this fixed order.
        for level in ("easy", "medium", "hard"):
            await run_task(llm_client, level)
        return

    print("CRITICAL WARNING: HF_TOKEN is missing! Make sure your .env file is set up correctly.")


# Script entry point: drive the async main() on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
|
|