| |
| """ |
| Inference script for the Government Service Application Assistant Environment. |
| |
| Uses OpenAI client with Groq API for stateful session management. |
| """ |
|
|
| import os |
| import sys |
| import json |
| import requests |
| from typing import Dict, Any |
|
|
# Runtime configuration; every value can be overridden via an environment
# variable of the same name.

# Base URL handed to the OpenAI client.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
# Model identifier sent with each chat-completion request.
MODEL_NAME = os.environ.get("MODEL_NAME", "llama-3.3-70b-versatile")
# Passed to the OpenAI client as its API key; empty when unset.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Base URL of the environment server (queried for /health and /tasks).
ENV_URL = os.environ.get("ENV_URL", "https://dharunkkk-gov-env.hf.space")
|
|
def check_prereqs():
    """Abort the process unless the required configuration values are present.

    ``HF_TOKEN`` is checked first, then ``ENV_URL``. For the first one that is
    empty an ``[ERROR]`` line is printed and the process exits with status 1.
    """
    requirements = (
        (HF_TOKEN, "[ERROR] HF_TOKEN not set. Please set your Groq API key as HF_TOKEN env var."),
        (ENV_URL, "[ERROR] ENV_URL not set. Please set your Space URL as ENV_URL env var."),
    )
    for configured_value, complaint in requirements:
        if not configured_value:
            print(complaint)
            sys.exit(1)
|
|
def check_space_health():
    """Probe ``{ENV_URL}/health`` and report whether it answered with HTTP 200.

    Returns:
        True when the GET request returns status 200. False for any other
        status, for a timeout (30 s limit), or for any other exception; the
        reason is printed in each failing case.
    """
    try:
        status = requests.get(f"{ENV_URL}/health", timeout=30).status_code
    except requests.exceptions.Timeout:
        print(f"[ERROR] Connection timeout to {ENV_URL}")
        return False
    except Exception as exc:
        print(f"[ERROR] Cannot connect to {ENV_URL}: {exc}")
        return False

    # Guard clause: anything other than 200 counts as unreachable.
    if status != 200:
        print(f"[ERROR] Space returned status {status}")
        return False

    print(f"[INFO] Space reachable: {ENV_URL}")
    return True
|
|
def get_tasks(base_url: str) -> list:
    """Fetch the task list from ``{base_url}/tasks``.

    Args:
        base_url: Root URL of the environment server.

    Returns:
        The value stored under ``"tasks"`` in the JSON reply, or an empty list
        when that key is absent. Any failure (connection problem, non-success
        HTTP status, undecodable body) is printed and also yields ``[]``.
    """
    try:
        reply = requests.get(f"{base_url}/tasks", timeout=30)
        reply.raise_for_status()
        payload = reply.json()
        return payload.get("tasks", [])
    except Exception as exc:
        print(f"[ERROR] Failed to fetch tasks: {exc}")
        return []
|
|
| from openai import OpenAI |
|
|
# Module-wide chat client: talks to API_BASE_URL and authenticates with HF_TOKEN.
llm_client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
|
|
# GovEnv (environment client) and GovAction (action model) come from the
# sibling ``client`` and ``models`` modules. If either import fails, report it
# and exit with status 1.
try:
    from client import GovEnv
    from models import GovAction
except ImportError as import_error:
    print(f"[ERROR] Import failed: {import_error}")
    sys.exit(1)
|
|
def run_inference(task_id: str, max_steps: int = 10) -> None:
    """Run one episode of the environment for ``task_id`` and log its progress.

    The function checks configuration and server reachability, looks the task
    up in the server's task list, then alternates between asking the model for
    an action and applying that action to the environment. Progress is printed
    as ``[START]``, ``[STEP]`` and ``[END]`` lines.

    Args:
        task_id: Identifier of the task to run; must appear in the list
            returned by ``get_tasks``.
        max_steps: Upper bound on the number of model/environment steps.
            Defaults to 10, the limit that was previously hard-coded.

    Exits the process with status 1 when prerequisites are missing, the server
    is unreachable, the task list is empty, the task is unknown, or the episode
    setup raises.
    """
    check_prereqs()

    if not check_space_health():
        print(f"[ERROR] Space not reachable at {ENV_URL}")
        sys.exit(1)

    tasks = get_tasks(ENV_URL)
    if not tasks:
        print("[ERROR] Could not fetch tasks from space")
        sys.exit(1)

    # ``.get`` keeps a task entry without a "task_id" key from raising KeyError.
    task_info = next((t for t in tasks if t.get("task_id") == task_id), None)
    if not task_info:
        available = [t.get("task_id") for t in tasks]
        print(f"[ERROR] Task {task_id} not found. Available: {available}")
        sys.exit(1)

    print(f"[START] task={task_id} env=gov_env model={MODEL_NAME}")

    try:
        with GovEnv(base_url=ENV_URL).sync() as env:
            result = env.reset()
            current_obs = result.observation
            step_count = 0
            total_reward = 0.0
            done = False
            rewards = []

            # The loop condition already stops on ``done``; no explicit break needed.
            while not done and step_count < max_steps:
                step_count += 1

                prompt = create_prompt(task_info, current_obs, step_count)

                try:
                    completion = llm_client.chat.completions.create(
                        model=MODEL_NAME,
                        messages=[{"role": "user", "content": prompt}],
                        temperature=0.1,
                        max_tokens=500
                    )
                    action_text = completion.choices[0].message.content

                    action_data = parse_action(action_text)

                    gov_action = GovAction(**action_data)
                    step_result = env.step(gov_action)

                    current_obs = step_result.observation
                    # A missing reward is counted as zero.
                    reward = step_result.reward if step_result.reward is not None else 0.0
                    done = step_result.done

                    total_reward += reward
                    rewards.append(reward)

                    print(f"[STEP] step={step_count} action={json.dumps(action_data)} reward={reward:.2f} done={str(done).lower()} error=null")

                except Exception as e:
                    # Any failure inside a step ends the episode after logging it.
                    print(f"[STEP] step={step_count} action=error reward=0.00 done=false error={str(e)}")
                    break

            # Success requires both a finished episode and a total reward above 0.5.
            success = done and total_reward > 0.5
            rewards_str = ",".join([f"{r:.2f}" for r in rewards]) if rewards else ""
            print(f"[END] success={str(success).lower()} steps={step_count} score={total_reward:.2f} rewards={rewards_str}")

    except Exception as e:
        print(f"[ERROR] Inference failed: {e}")
        sys.exit(1)
|
|
def create_prompt(task_info: Dict[str, Any], observation, step: int) -> str:
    """Render the task and the latest observation as a text prompt for the model.

    Args:
        task_info: Task record; ``"description"`` is required, ``"service_type"``
            and ``"expected_difficulty"`` are optional.
        observation: Latest observation. If it has a ``model_dump`` method, the
            result of calling it is used as the state; otherwise the state is
            treated as empty.
        step: Step number shown in the prompt.

    Returns:
        The prompt: a header, the current state, the document and validation
        sections that apply, and the fixed list of allowed actions.
    """
    description = task_info["description"]
    service = task_info.get("service_type", "")
    level = task_info.get("expected_difficulty", "")

    state = observation.model_dump() if hasattr(observation, 'model_dump') else {}

    # Fragments are collected here and concatenated once at the end.
    fragments = [f"""You are an AI assistant helping users with Indian government service applications.

TASK: {description}
SERVICE TYPE: {service}
DIFFICULTY: {level}
CURRENT STEP: {step}

CURRENT STATE:
- Stage: {state.get('current_stage', 'unknown')}
- Service: {state.get('service_type', 'none')}
- Message: {state.get('message', '')}

REQUIRED DOCUMENTS:
"""]

    required = state.get('required_documents', [])
    if required:
        fragments.extend(
            f"- {entry.get('type', '')}: {entry.get('description', '')}\n"
            for entry in required
        )
    else:
        fragments.append("None specified yet\n")

    submitted = state.get('submitted_documents', [])
    if submitted:
        fragments.append("\nSUBMITTED DOCUMENTS:\n")
        fragments.extend(
            f"- {entry.get('type', '')}: {entry.get('details', '')}\n"
            for entry in submitted
        )

    validation = state.get('validation_results')
    if validation:
        fragments.append("\nVALIDATION RESULTS:\n")
        fragments.append(f"- Complete: {validation.get('is_complete', False)}\n")
        fragments.append(f"- Valid: {validation.get('is_valid', False)}\n")

        missing_docs = validation.get('missing_documents', [])
        if missing_docs:
            fragments.append(f"- Missing: {', '.join(missing_docs)}\n")

        invalid_docs = validation.get('invalid_documents', [])
        if invalid_docs:
            fragments.append(f"- Invalid: {len(invalid_docs)} documents have issues\n")

    suggestions = state.get('correction_suggestions', [])
    if suggestions:
        fragments.append("\nCORRECTION SUGGESTIONS:\n")
        fragments.extend(
            f"- {item.get('suggested_action', '')}\n" for item in suggestions
        )

    fragments.append("""
Respond with valid JSON action_types:
1. select_service - provide service_type
2. list_required_documents - no extra fields
3. validate_documents - provide documents array
4. suggest_corrections - no extra fields
5. submit_application - no extra fields

Example: {"action_type": "select_service", "service_type": "passport_new"}
""")

    return "".join(fragments)
|
|
def parse_action(action_text: str) -> Dict[str, Any]:
    """Turn the model's reply into a dictionary of action fields.

    Parsing is attempted in this order:

    1. The whole reply (surrounding whitespace ignored) as JSON.
    2. The span from the first ``{`` to the last ``}``, which recovers a JSON
       object wrapped in markdown code fences or surrounded by prose.
    3. A keyword heuristic over the lower-cased reply that sets
       ``"action_type"`` (and ``"service_type"`` for passport requests), with
       the raw reply kept under ``"message"``.

    Only JSON *objects* are accepted in steps 1 and 2; any other JSON value
    (list, string, number) falls through to the heuristic so the result is
    always a dictionary.

    Args:
        action_text: Raw reply text. A non-string value such as ``None`` is
            treated as an empty reply instead of raising.

    Returns:
        The parsed JSON object, or the heuristic dictionary. The heuristic
        result has no ``"action_type"`` key when no keyword matched.
    """
    text = action_text if isinstance(action_text, str) else ""

    stripped = text.strip()
    candidates = [stripped]
    first_brace, last_brace = stripped.find("{"), stripped.rfind("}")
    if 0 <= first_brace < last_brace:
        candidates.append(stripped[first_brace:last_brace + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    # No usable JSON object: fall back to keyword matching on the raw text.
    action_data: Dict[str, Any] = {"message": text}

    text_lower = text.lower()
    if "select_service" in text_lower:
        action_data["action_type"] = "select_service"
        if "passport" in text_lower:
            action_data["service_type"] = "passport_new"
    elif "list_required" in text_lower or "required_documents" in text_lower:
        action_data["action_type"] = "list_required_documents"
    elif "validate" in text_lower:
        action_data["action_type"] = "validate_documents"
    elif "suggest" in text_lower or "correction" in text_lower:
        action_data["action_type"] = "suggest_corrections"
    elif "submit" in text_lower:
        action_data["action_type"] = "submit_application"

    return action_data
|
|
def main():
    """Command-line entry point.

    Expects exactly one argument, the task id, and passes it to
    ``run_inference``. With any other argument count a usage line is printed
    and the process exits with status 1.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        print("Usage: python inference.py <task_id>")
        sys.exit(1)

    run_inference(cli_args[0])


if __name__ == "__main__":
    main()