Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Inference script for Doc Quality Environment.

This script runs an LLM agent against the documentation quality assessment environment.
It demonstrates how an AI agent can evaluate and improve technical documentation.

Environment variables:
    API_BASE_URL: LLM API endpoint (default: https://router.huggingface.co/v1)
    MODEL_NAME: Model identifier (default: Qwen/Qwen2.5-7B-Instruct)
    HF_TOKEN: Hugging Face API token (required)
"""
| import os | |
| import json | |
| import textwrap | |
| from typing import Optional, List | |
| from openai import OpenAI | |
| from doc_quality_env.server.doc_quality_env_environment import DocQualityEnvironment | |
| from doc_quality_env.models import DocQualityAction | |
# Runtime configuration, overridable through environment variables.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")

# The token has no default value: refuse to start without it.
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN environment variable is required")

# OpenAI-compatible client pointed at the configured endpoint.
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

# Episode and sampling settings.
TASKS = ["easy_api_doc", "medium_api_doc", "hard_guide_review"]
MAX_STEPS_PER_TASK = 10
TEMPERATURE = 0.7
MAX_TOKENS = 200
def log_start(task: str, env: str, model: str) -> None:
    """Log the start of an episode.

    Prints one ``[START]`` record to stdout and flushes immediately.

    Args:
        task: Task identifier reported in the ``task=`` field.
        env: Environment name reported in the ``env=`` field.
        model: Model identifier reported in the ``model=`` field.
    """
    # Report the caller-supplied environment name instead of a fixed literal,
    # so the ``env`` argument is no longer silently ignored.
    print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Log a step in the episode.

    Prints one ``[STEP]`` record to stdout and flushes immediately.

    Args:
        step: Step number reported in the ``step=`` field.
        action: Description of the action taken.
        reward: Reward for the step, printed with two decimals.
        done: Whether the episode finished; printed as ``true``/``false``.
        error: Error message, or ``None``/empty to print ``null``.
    """
    # Collapse whitespace runs (including line breaks) so free-form text
    # cannot split the record across several output lines.
    action_val = " ".join(str(action).split())
    error_val = " ".join(str(error).split()) if error else ""
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action_val} reward={reward:.2f} "
        f"done={done_val} error={error_val or 'null'}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Log the end of an episode.

    Prints one ``[END]`` record to stdout and flushes immediately.

    Args:
        success: Episode outcome, printed as ``true``/``false``.
        steps: Number of steps taken.
        score: Final score, printed with two decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    formatted = [f"{value:.2f}" for value in rewards]
    rewards_str = ",".join(formatted)
    success_val = str(success).lower()
    line = f"[END] success={success_val} steps={steps} score={score:.2f} rewards={rewards_str}"
    print(line, flush=True)
def call_llm(prompt: str) -> str:
    """Call the LLM to get the agent's next action.

    Sends a fixed system message plus ``prompt`` as the user message through
    the module-level ``client``, using ``MODEL_NAME``, ``TEMPERATURE`` and
    ``MAX_TOKENS``.

    Args:
        prompt: User message for the model.

    Returns:
        The first choice's message text with surrounding whitespace removed.

    Raises:
        RuntimeError: If the request fails or the response carries no usable
            text. The underlying exception is attached as ``__cause__``.
    """
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "system",
                    "content": "You are an expert technical documentation reviewer. Provide clear, actionable feedback on documentation quality.",
                },
                {"role": "user", "content": prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        content = response.choices[0].message.content
        # The message content may be absent; report that explicitly instead
        # of failing with an opaque AttributeError on ``None.strip()``.
        if content is None:
            raise ValueError("response contained no message content")
        return content.strip()
    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise RuntimeError(f"LLM call failed: {e}") from e
def parse_agent_response(response: str) -> tuple:
    """
    Parse the LLM response into action components.

    Expected format: "ACTION_TYPE|CATEGORY|CONTENT"

    When the response does not contain at least two ``|`` separators, the
    action type is inferred from keywords in the text, the category defaults
    to ``"clarity"`` and the whole response is used as the content.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        A ``(action_type, category, content)`` tuple.
    """
    # A non-string response cannot be split or searched, so return the
    # defaults directly. This replaces a bare ``except:`` that also swallowed
    # KeyboardInterrupt and SystemExit.
    if not isinstance(response, str):
        return "identify_issue", "clarity", response

    # Split on the first two separators only, so the content may contain "|".
    parts = response.split("|", 2)
    if len(parts) >= 3:
        action_type, category, content = (part.strip() for part in parts)
        return action_type, category, content

    # Fallback parsing - the first action whose keyword appears wins.
    lowered = response.lower()
    keyword_table = (
        ("identify_issue", ("identify", "issue")),
        ("suggest_improvement", ("suggest", "improve", "fix")),
        ("rate_quality", ("rate", "score", "quality")),
    )
    for action_type, keywords in keyword_table:
        if any(keyword in lowered for keyword in keywords):
            return action_type, "clarity", response
    return "identify_issue", "clarity", response
def build_prompt(step: int, obs) -> str:
    """Build the prompt for the LLM based on current observation.

    The prompt is assembled line by line rather than by dedenting an
    already-interpolated template: multi-line document text would otherwise
    change the common indentation that ``textwrap.dedent`` removes, and
    whitespace-only lines inside the document would be rewritten.

    Args:
        step: Current step number shown to the model.
        obs: Observation exposing ``current_doc``, ``issues_identified``,
            ``known_issues``, ``task_name``, ``task_difficulty``,
            ``max_steps`` and ``feedback``.

    Returns:
        The prompt text, one section per line group, without leading or
        trailing blank lines.
    """
    # Show at most the first 500 characters of the document.
    document = obs.current_doc
    doc_preview = document[:500] + "..." if len(document) > 500 else document

    # Only the three most recent findings are repeated back to the model.
    if obs.issues_identified:
        issues_str = "\n".join(f"- {issue}" for issue in obs.issues_identified[-3:])
    else:
        issues_str = "None yet"

    # The first three known issues are offered as hints.
    known_str = "\n".join(f"- {issue}" for issue in obs.known_issues[:3])

    sections = [
        f"Task: {obs.task_name}",
        f"Difficulty: {obs.task_difficulty}",
        f"Step: {step}/{obs.max_steps}",
        "Documentation Preview:",
        doc_preview,
        "Issues You've Already Identified:",
        issues_str,
        "Hints (Sample Known Issues):",
        known_str,
        f"Last Feedback: {obs.feedback}",
        "Your Options:",
        "1. Identify another issue in the documentation (format: identify_issue|CATEGORY|DESCRIPTION)",
        "2. Suggest how to improve it (format: suggest_improvement|CATEGORY|SUGGESTION)",
        "3. Rate the overall quality (format: rate_quality|overall|SCORE_0_TO_1)",
        "Respond with ONE action in the format above. Be specific and actionable.",
    ]
    return "\n".join(sections)
def run_task_episode(env: DocQualityEnvironment, task_key: str) -> tuple:
    """Run a single episode on a task.

    Resets ``env``, then for up to ``MAX_STEPS_PER_TASK`` steps builds a
    prompt, asks the LLM for an action, applies it to the environment and
    logs the step. The episode ends early when the observation reports
    ``done``.

    Args:
        env: Environment to run; it is reset at the start of the episode.
        task_key: Task label written to the ``[START]`` log record.
            NOTE(review): the key is not passed to ``env.reset()`` here, so
            this function does not select the task with it. Confirm how the
            environment chooses its task.

    Returns:
        ``(final_score, success, all_rewards)``. On a normal run the score is
        the ratio of identified issues to known issues, capped at 1.0. If a
        step raises, the score is the mean reward of the completed steps.
        ``success`` is true only when the environment reported ``done``.
    """
    obs = env.reset()
    step_count = 0
    total_reward = 0.0
    all_rewards: List[float] = []
    success = False
    log_start(task_key, "doc_quality_env", MODEL_NAME)
    try:
        for step in range(1, MAX_STEPS_PER_TASK + 1):
            # Get agent action from LLM
            prompt = build_prompt(step, obs)
            llm_response = call_llm(prompt)
            # Parse the response
            action_type, category, content = parse_agent_response(llm_response)
            # Create action
            action = DocQualityAction(
                action_type=action_type, content=content, issue_category=category
            )
            # Execute action
            obs = env.step(action)
            step_count += 1
            # Count a missing reward as 0.0 so that neither the running total
            # nor the numeric formatting in the log helpers fails on None.
            reward = obs.reward if obs.reward is not None else 0.0
            all_rewards.append(reward)
            total_reward += reward
            # Log the step with a short preview of the action content
            action_str = (
                f"{action_type}('{content[:30]}'...)"
                if len(content) > 30
                else f"{action_type}('{content}')"
            )
            log_step(step, action_str, reward, obs.done, None)
            if obs.done:
                success = True
                break
        # Final score based on issues found
        final_score = min(
            1.0, len(obs.issues_identified) / max(len(obs.known_issues), 1)
        )
    except Exception as e:
        # Best effort: score the partial episode and record the failure as a
        # terminal step instead of aborting the whole run.
        final_score = total_reward / step_count if step_count else 0.0
        log_step(step_count + 1, "error", 0.0, True, str(e))
    log_end(success, step_count, final_score, all_rewards)
    return final_score, success, all_rewards
def main():
    """Run the inference script on all tasks.

    Prints a banner, runs one episode per entry in ``TASKS`` on a fresh
    environment, and prints the average score. A task that raises is reported
    with an ``[ERROR]`` line and scored 0.0.
    """
    rule = "=" * 60
    banner = (
        rule,
        "Doc Quality Environment - Inference Script",
        f"Model: {MODEL_NAME}",
        f"API: {API_BASE_URL}",
        rule,
        "",
    )
    for banner_line in banner:
        print(banner_line, flush=True)

    task_scores = []
    for task_key in TASKS:
        print(f"Running task: {task_key}", flush=True)
        env = DocQualityEnvironment()
        try:
            # Reset to initialize.
            # NOTE(review): run_task_episode resets the environment again;
            # confirm whether this extra reset is needed.
            env.reset()
            score, _success, _rewards = run_task_episode(env, task_key)
            task_scores.append(score)
        except Exception as e:
            print(f"[ERROR] Task {task_key} failed: {e}", flush=True)
            task_scores.append(0.0)
        finally:
            # Always release the environment, even when the task failed.
            env.close()
        print("", flush=True)

    # Summary
    if task_scores:
        avg_score = sum(task_scores) / len(task_scores)
    else:
        avg_score = 0.0
    print(rule, flush=True)
    print(f"Summary: Average Score = {avg_score:.2f}", flush=True)
    print(rule, flush=True)
# Run the task suite only when this file is executed as a script.
if __name__ == "__main__":
    main()