# citeGuardian / inference.py
# Copied from the Hugging Face Hub file viewer: uploaded by zrypton
# ("Upload folder using huggingface_hub"), commit 5d50cde (verified), 8.33 kB.
"""
Inference Script — CiteGuardian
===================================
MANDATORY
- Before submitting, ensure the following variables are defined in your environment configuration:
    API_BASE_URL      The API endpoint for the LLM.
    MODEL_NAME        The model identifier to use for inference.
    HF_TOKEN          Your Hugging Face / API key.
    LOCAL_IMAGE_NAME  The name of the local image to use for the environment if you are using from_docker_image().
- Defaults are set only for API_BASE_URL and MODEL_NAME:
    API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
    MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.1-mini")

STDOUT FORMAT
    [START] task=<task_name> env=<benchmark> model=<model_name>
    [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
    [END] success=<true|false> steps=<n> rewards=<r1,r2,...,rn>
"""
import asyncio
import json
import os
import textwrap
from typing import List, Optional
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from citeGuardian import CiteguardianAction, CiteguardianEnv
# Runtime configuration, taken from the process environment.
# LLM endpoint and model: these two fall back to defaults when unset.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4.1-mini")

# The API key has no default; refuse to start without it.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN environment variable is required")

# Where the environment lives: a docker image name, or a server URL
# (ENV_URL is consulted first, SERVER_URL second).
IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")
ENV_URL = os.environ.get("ENV_URL") or os.environ.get("SERVER_URL")

# Labels echoed in the [START] log line.
TASK_NAME = os.environ.get("CITEGUARDIAN_TASK", "audit")
BENCHMARK = os.environ.get("CITEGUARDIAN_BENCHMARK", "citeGuardian")

# Episode length, sampling parameters, and the score needed for success.
MAX_STEPS = 30
TEMPERATURE = 0.2
MAX_TOKENS = 256
SUCCESS_SCORE_THRESHOLD = 0.5

# OpenAI-compatible chat client used for every model call.
client_llm = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
# System prompt used as the first chat message of every episode. It lists the
# JSON action schemas the model may emit, the mandatory paper sections, and the
# procedure for each of the three tasks (A: structure, B: citations, C: numeric
# consistency). dedent() + strip() remove the literal's common indentation and
# the surrounding blank lines.
SYSTEM_PROMPT = textwrap.dedent("""
You are an expert academic peer-reviewer auditing a research paper.
Respond with EXACTLY one JSON object per turn.
TOOLS:
{"action_type": "GO_TO", "section_name": "<name>"}
{"action_type": "SCAN_CITATIONS"}
{"action_type": "COMPARE_VALUES", "val1": "<v1>", "val2": "<v2>"}
{"action_type": "FLAG_ERROR", "error_type": "<STRUCTURAL_ERROR|ORPHAN_CITATION|LOGICAL_INCONSISTENCY>", "text_snippet": "<snippet>"}
{"action_type": "SUBMIT"}
MANDATORY SECTIONS: Abstract, Introduction, Methods, Results, Discussion, References
Task A: If any mandatory section missing in available_sections → FLAG_ERROR STRUCTURAL_ERROR text_snippet=<section name> → SUBMIT.
Task B: SCAN_CITATIONS in each section, cross-check with References → FLAG_ERROR ORPHAN_CITATION for mismatches → SUBMIT.
Task C: GO_TO Methods and Results, COMPARE_VALUES on numeric claims → FLAG_ERROR LOGICAL_INCONSISTENCY if conflict → SUBMIT.
RULES: Only flag when certain. Never flag same error twice. Reply ONLY with valid JSON, no markdown.
""").strip()
def _obs_to_prompt(step: int, result) -> str:
    """Render the latest environment observation as the next user prompt.

    Args:
        step: Step number shown on the first prompt line.
        result: Either an object exposing ``.observation`` or the observation
            itself. Missing attributes and missing/empty metadata fall back
            to placeholder values.

    Returns:
        A newline-joined prompt whose last line is ``Next action?``.
    """
    obs = result.observation if hasattr(result, "observation") else result

    # Resolve metadata once; a missing or empty mapping degrades to defaults.
    meta = getattr(obs, "metadata", None) or {}
    available = meta.get("available_sections") or []
    mandatory = ("Abstract", "Introduction", "Methods", "Results", "Discussion", "References")
    missing = [name for name in mandatory if name not in available]

    # The attribute may exist but be None, which cannot be sliced.
    view = getattr(obs, "current_view", "") or ""
    audit_log = getattr(obs, "audit_log", None)
    recent_log = audit_log[-3:] if audit_log else []

    # Lines are assembled directly rather than dedenting an interpolated
    # template: dedent() after interpolation stops working as soon as an
    # inserted value (e.g. a multi-line view) contains an unindented line.
    lines = [
        f"Step: {step} | Task: {getattr(obs, 'task_level', '?')}",
        f"Section: {meta.get('current_section', '?')}",
        f"Available: {available} | Missing: {missing or 'none'}",
        f"Visited: {meta.get('visited_sections', [])}",
        f"Citations here: {meta.get('citation_markers_in_view', [])}",
        f"All citations: {meta.get('all_paper_citations', [])}",
        f"Flags raised: {meta.get('flags_raised', 0)}",
        f"Tool result: {json.dumps(getattr(obs, 'tool_result', None))}",
        f"Message: {getattr(obs, 'message', '')}",
        f"View: {view[:600]}",
        f"Log: {json.dumps(recent_log)}",
        "Next action?",
    ]
    return "\n".join(lines)
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens an episode's stdout log.

    Args:
        task: Task name reported in the ``task=`` field.
        env: Benchmark name reported in the ``env=`` field.
        model: Model identifier reported in the ``model=`` field.
    """
    header = f"[START] task={task} env={env} model={model}"
    print(header, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record on a single stdout line.

    Args:
        step: Step number.
        action: Serialized action string, written verbatim.
        reward: Step reward, formatted with two decimals.
        done: Whether the episode ended on this step (printed as true/false).
        error: Error message for the step, or ``None``/empty for ``null``.
    """
    # Exception messages can span several lines; join them so the record
    # stays one line, as the [STEP] format requires. Single-line messages
    # pass through unchanged.
    error_field = " ".join(str(error).splitlines()) if error else ""
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_field or 'null'}",
        flush=True,
    )
def log_end(success: bool, steps: int, rewards: List[float]) -> None:
    """Print the closing ``[END]`` record for an episode.

    Args:
        success: Episode outcome, printed as true/false.
        steps: Number of steps taken.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    formatted = [f"{r:.2f}" for r in rewards]
    outcome = str(success).lower()
    print(f"[END] success={outcome} steps={steps} rewards={','.join(formatted)}", flush=True)
def _extract_json_object(text: str) -> dict:
    """Return the first JSON object that can be decoded from *text*.

    Decoding starts at a ``{`` character, so Markdown code fences, a language
    tag in any case, and prose before or after the object are all ignored.

    Raises:
        ValueError: If *text* contains no decodable JSON object.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            payload, _ = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            # Not a valid object at this brace; try the next one.
            start = text.find("{", start + 1)
        else:
            return payload
    raise ValueError("no JSON object found in model reply")


def get_model_action(step: int, result, messages: List[ChatCompletionMessageParam]) -> CiteguardianAction:
    """Ask the model for the next action given the latest observation.

    Appends the rendered observation to *messages* as a user turn and, once
    the API call returns, the raw reply as an assistant turn, so the list
    carries the conversation across steps.

    Args:
        step: Current step number (used in the prompt and debug output).
        result: Latest environment result/observation.
        messages: Running chat history; mutated in place.

    Returns:
        The action parsed from the model reply. Any failure (API error,
        unparsable reply, invalid action fields) is logged and replaced by a
        ``SUBMIT`` action.
    """
    messages.append({"role": "user", "content": _obs_to_prompt(step, result)})
    try:
        completion = client_llm.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        raw = (completion.choices[0].message.content or "").strip()
        messages.append({"role": "assistant", "content": raw})
        # Tolerate fenced or chatty replies instead of requiring bare JSON.
        return CiteguardianAction(**_extract_json_object(raw))
    except Exception as exc:
        # Deliberate best-effort boundary: never let a bad reply crash the run.
        print(f"[DEBUG] Model/parse error at step {step}: {exc}", flush=True)
        return CiteguardianAction(action_type="SUBMIT")
def _reward_as_float(value) -> float:
    """Coerce an optional reward to ``float``, treating ``None`` as 0.0."""
    return 0.0 if value is None else float(value)


async def main() -> None:
    """Run one episode and emit the ``[START]``/``[STEP]``/``[END]`` log.

    Connects to the environment at ``ENV_URL`` when set, otherwise starts it
    from ``IMAGE_NAME``. It then alternates model actions and environment
    steps for at most ``MAX_STEPS`` steps or until the environment reports
    done. Success means the last observed reward, clamped to [0, 1], reaches
    ``SUCCESS_SCORE_THRESHOLD``. The environment is always closed and the
    ``[END]`` line always printed, even after an error.
    """
    rewards: List[float] = []
    steps_taken = 0
    success = False
    env = None

    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
    try:
        if ENV_URL:
            env = CiteguardianEnv(base_url=ENV_URL)
            await env.connect()
        else:
            env = await CiteguardianEnv.from_docker_image(IMAGE_NAME)

        messages: List[ChatCompletionMessageParam] = [{"role": "system", "content": SYSTEM_PROMPT}]
        result = await env.reset()
        last_reward = 0.0

        # A done flag on the reset result is not consulted: the first action
        # is always attempted. After that the loop leaves as soon as a step
        # reports done, so no extra check is needed at the top of the loop.
        for step in range(1, MAX_STEPS + 1):
            action = get_model_action(step, result, messages)
            action_str = action.model_dump_json(exclude_none=True)
            try:
                result = await env.step(action)
                obs = result.observation if hasattr(result, "observation") else result
                obs_reward = getattr(obs, "reward", None)
                # Prefer the result-level reward, fall back to the
                # observation's; a reward of None counts as 0.0 instead of
                # raising and aborting the episode.
                reward = _reward_as_float(getattr(result, "reward", None) or obs_reward)
                done = getattr(result, "done", False) or getattr(obs, "done", False)
                last_reward = reward if obs_reward is None else float(obs_reward)
                error_msg = None
            except Exception as step_exc:
                print(f"[DEBUG] Step {step} error: {step_exc}", flush=True)
                reward = 0.0
                done = True
                error_msg = str(step_exc)

            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_str, reward=reward, done=done, error=error_msg)
            if done:
                break

        score = min(max(last_reward, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
    except Exception as exc:
        print(f"[DEBUG] Episode error: {exc}", flush=True)
        import traceback
        traceback.print_exc()
    finally:
        if env is not None:
            try:
                await env.close()
            except Exception as e:
                print(f"[DEBUG] env.close() error: {e}", flush=True)
        log_end(success=success, steps=steps_taken, rewards=rewards)
# Script entry point: run the async episode driver when executed directly.
if __name__ == "__main__":
    asyncio.run(main())