Spaces:
Sleeping
Sleeping
| """ | |
| AgentForge - Hugging Face Space Template | |
| This is a generic, reusable agent runner that reads configuration from environment variables. | |
| """ | |
| import os | |
| import json | |
| from fastapi import FastAPI, Request, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from typing import Optional, List, Dict, Any | |
| from agents import Agent, AsyncOpenAI as AgentsAsyncOpenAI, OpenAIChatCompletionsModel, function_tool, Runner, SQLiteSession | |
# ============================================
# Load Agent Configuration from Environment
# ============================================
def _extract_agent_build(config):
    """Return the agent_build payload from a parsed AGENT_CONFIG value.

    Accepted shapes:
      * the full API response:  {"result": {"agent_build": {...}}}
      * a thin wrapper:         {"agent_build": {...}}
      * the flat agent_build dict itself
    Any other value (including non-dicts) is returned unchanged so the caller
    can validate it.
    """
    if not isinstance(config, dict):
        return config
    result = config.get("result")
    # Only probe `result` when it is a dict: a null result would make `in`
    # raise TypeError, and a string result would do a substring match.
    if isinstance(result, dict) and "agent_build" in result:
        return result["agent_build"]
    if "agent_build" in config:
        return config["agent_build"]
    # Otherwise assume it's already the flat agent_build structure
    return config


AGENT_CONFIG_STR = os.getenv("AGENT_CONFIG")
if not AGENT_CONFIG_STR:
    raise ValueError("AGENT_CONFIG environment variable is required")

# Parse the config - handle both nested and flat structures.
# JSONDecodeError is itself a ValueError, so existing handlers still match;
# re-raising just gives a message that names the offending variable.
try:
    raw_config = json.loads(AGENT_CONFIG_STR)
except json.JSONDecodeError as err:
    raise ValueError(f"AGENT_CONFIG is not valid JSON: {err}") from err

AGENT_CONFIG = _extract_agent_build(raw_config)

# Only the overall type is checked here; individual keys are read with
# .get() fallbacks wherever they are used.
if not isinstance(AGENT_CONFIG, dict):
    raise ValueError("AGENT_CONFIG must be a dictionary")

# API Keys from environment (None when unset; checked in initialize_agent)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GROK_API_KEY = os.getenv("GROK_API_KEY")
# ============================================
# FastAPI App Setup
# ============================================
def _app_metadata():
    """Return ``(title, description)`` for the FastAPI app from AGENT_CONFIG.

    The agent name comes from ``name``, then ``agent_name``, then "Agent".
    The business name comes from ``business_context.business_name``, then the
    top-level ``business_name``, then "Business".
    A ``business_context`` that is not a dict (e.g. JSON null) is ignored
    rather than crashing the module at import time.
    """
    context = AGENT_CONFIG.get("business_context")
    if not isinstance(context, dict):
        context = {}
    agent_name = AGENT_CONFIG.get("name") or AGENT_CONFIG.get("agent_name") or "Agent"
    business_name = (
        context.get("business_name")
        or AGENT_CONFIG.get("business_name")
        or "Business"
    )
    return f"{agent_name} API", f"Deployed agent for {business_name}"


_APP_TITLE, _APP_DESCRIPTION = _app_metadata()

app = FastAPI(
    title=_APP_TITLE,
    description=_APP_DESCRIPTION,
    version="1.0.0",
)

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this API is meant to be callable from any website.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============================================
# Request/Response Models
# ============================================
# Request body accepted by the chat handler (run_agent).
# Comments (not a docstring) are used so the generated OpenAPI schema is unchanged.
class ChatRequest(BaseModel):
    # Text passed to the agent as the run input; required.
    message: str = Field(..., description="User message to the agent")
    # Client-supplied conversation label; "default" when omitted.
    # NOTE(review): run_agent in this file builds a fresh in-memory session per
    # request, so this value does not currently link calls together — confirm
    # whether conversation tracking is meant to be implemented.
    session_id: Optional[str] = Field(default="default", description="Session ID for conversation tracking")
# Response body returned by the chat handler (run_agent).
class ChatResponse(BaseModel):
    status: str  # run_agent sets "success" on a completed run
    agent_name: Optional[str] = None  # May be missing in config
    user_message: str  # echo of the request's message
    agent_response: str  # the agent's final output, converted to str
    tools_available: List[str]  # names of the tools the agent was built with
    timestamp: float  # time.time() when the response was built (Unix seconds)
# ============================================
# Dynamic Tool Recreation
# ============================================
def recreate_tools_from_config(domain: str, business_name: str):
    """
    Recreate tools based on domain.
    This mirrors the DynamicToolFactory logic from agent_architect.py

    Every tool is a stub that returns canned data; none calls a real backend.

    Args:
        domain: "pharmacy", "ecommerce" or "weather" select a dedicated tool
            set; any other value (including None) selects the generic set.
            Matching ignores case and surrounding whitespace.
        business_name: accepted for interface compatibility; currently unused.

    Returns:
        list: the domain's tools, each wrapped with ``function_tool`` so the
        list can be passed straight to ``Agent(tools=...)``.
    """
    domain_key = str(domain or "generic").strip().lower()

    if domain_key == "pharmacy":
        async def manage_prescription(action: str, prescription_id: Optional[str] = None, patient_id: Optional[str] = None, medication: Optional[str] = None) -> dict:
            """Manage prescriptions - check, refill, or create"""
            from datetime import datetime
            return {"prescription_id": prescription_id or f"RX-{datetime.now().strftime('%Y%m%d%H%M')}",
                    "action": action, "status": "Processed", "refills": 3}

        async def check_drug_inventory(medication_name: str) -> dict:
            """Check medication stock and expiry"""
            return {"medication": medication_name, "in_stock": True, "quantity": 250, "expiry": "2026-06-15"}

        async def get_patient_info(patient_id: str) -> dict:
            """Retrieve patient records and allergies"""
            return {"patient_id": patient_id, "allergies": ["Penicillin"], "medications": ["Metformin"]}

        def web_search(query: str) -> dict:
            """Perform a web search for current information"""
            return {"query": query, "results": "Web search functionality - integrate with real API"}

        raw_tools = [manage_prescription, check_drug_inventory, get_patient_info, web_search]

    elif domain_key == "ecommerce":
        async def search_products(query: str, category: Optional[str] = None) -> dict:
            """Search product catalog"""
            return {"query": query, "results": [{"id": "P001", "name": query, "price": 49.99, "stock": 50}]}

        async def track_order(order_id: str) -> dict:
            """Track order status and delivery"""
            return {"order_id": order_id, "status": "In Transit", "eta": "2025-11-20", "location": "Distribution Center"}

        async def manage_cart(action: str, product_id: Optional[str] = None, quantity: int = 1) -> dict:
            """Add, remove, or view cart items"""
            return {"action": action, "product_id": product_id, "cart_total": 149.99, "items": 3}

        def web_search(query: str) -> dict:
            """Perform a web search for current information"""
            return {"query": query, "results": "Web search functionality"}

        raw_tools = [search_products, track_order, manage_cart, web_search]

    elif domain_key == "weather":
        async def get_forecast(location: str, days: int = 7) -> dict:
            """Get weather forecast"""
            return {"location": location, "days": days, "forecast": [{"date": "2025-12-12", "high": 22, "low": 15, "condition": "partly cloudy"}]}

        async def severe_weather_alert(location: str) -> dict:
            """Check for severe weather alerts"""
            return {"location": location, "alerts": [], "severity": "none", "preparedness_tips": ["Normal precautions"]}

        async def historical_weather_comparison(location: str, date: str) -> dict:
            """Compare current weather to historical data"""
            return {"location": location, "date": date, "current_temp": 20, "historical_avg": 18, "difference": 2, "percentile": 65}

        def web_search(query: str) -> dict:
            """Perform a web search for current information"""
            return {"query": query, "results": "Web search functionality"}

        raw_tools = [get_forecast, severe_weather_alert, historical_weather_comparison, web_search]

    # Add more domains as needed...
    else:  # generic
        async def generate_analytics(metric: str, time_range: str) -> dict:
            """Generate business analytics"""
            return {"metric": metric, "time_range": time_range, "value": 12500, "trend": "+15%", "insights": f"{metric} growing"}

        async def send_notification(recipient: str, message: str, channel: str = "email") -> dict:
            """Send notifications"""
            return {"recipient": recipient, "message": message, "channel": channel, "status": "Sent"}

        def web_search(query: str) -> dict:
            """Perform a web search for current information"""
            return {"query": query, "results": "Web search functionality"}

        raw_tools = [generate_analytics, send_notification, web_search]

    # Agent(tools=...) expects FunctionTool objects, not bare callables, so
    # every function is wrapped here. strict_mode=False keeps parameters that
    # have Python defaults optional in the generated JSON schema; strict
    # schemas require every parameter to be listed as required.
    return [function_tool(fn, strict_mode=False) for fn in raw_tools]
# ============================================
# Initialize Agent
# ============================================
def initialize_agent():
    """Initialize the agent with configuration from environment.

    The configured ``model`` string only selects the provider (it is checked
    for "gemini" or "grok", otherwise OpenAI). Each provider uses a fixed
    model id and its own API key from the environment.

    Returns:
        tuple: ``(agent, tools)`` — the constructed Agent and the tool list
        it was given.

    Raises:
        ValueError: if the API key for the selected provider is not set.
    """
    # `or` rather than a .get() default so an explicit JSON null "model"
    # also falls back instead of crashing on .lower().
    model = AGENT_CONFIG.get("model") or "gpt-4o"
    model_key = str(model).lower()

    if "gemini" in model_key:
        api_key = GEMINI_API_KEY
        base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
        model_name = "gemini-2.0-flash-exp"
    elif "grok" in model_key:
        api_key = GROK_API_KEY
        base_url = "https://api.x.ai/v1"
        model_name = "grok-beta"
    else:
        api_key = OPENAI_API_KEY
        base_url = None  # None -> the client's default OpenAI endpoint
        model_name = "gpt-4o"
    # NOTE(review): the model ids above are hard-coded; confirm each provider
    # still serves them.

    # Check the key *before* building the client: when given no key, the
    # OpenAI client falls back to the OPENAI_API_KEY env var (or raises its own
    # error), which would hide this provider-specific message.
    if not api_key:
        raise ValueError(f"API key not found for model: {model}")

    client = AgentsAsyncOpenAI(api_key=api_key, base_url=base_url)
    chat_model = OpenAIChatCompletionsModel(model=model_name, openai_client=client)

    # Recreate tools - handle both nested and flat business_context
    business_context = AGENT_CONFIG.get("business_context")
    if not isinstance(business_context, dict):
        business_context = {}
    domain = business_context.get("domain") or AGENT_CONFIG.get("domain", "generic")
    business_name = business_context.get("business_name") or AGENT_CONFIG.get("business_name", "Business")
    tools = recreate_tools_from_config(domain, business_name)

    # Name/instructions: null or empty config values fall back to defaults.
    agent_name = AGENT_CONFIG.get("name") or AGENT_CONFIG.get("agent_name") or "AI Agent"
    instructions = AGENT_CONFIG.get("instructions") or "You are a helpful AI assistant."

    agent = Agent(
        name=agent_name,
        instructions=instructions,
        model=chat_model,
        tools=tools,
    )
    return agent, tools


# Built once at import time; the request handlers read these module globals.
AGENT_INSTANCE, AGENT_TOOLS = initialize_agent()
# ============================================
# API Endpoints
# ============================================
# NOTE(review): none of the endpoint handlers in this copy of the file shows an
# @app.get/@app.post route decorator, so FastAPI would not serve them; confirm
# they are registered with `app`.


def _tool_display_name(tool) -> str:
    """Return a short, human-readable name for one entry of AGENT_TOOLS.

    Tries ``__name__`` (plain functions), then ``name`` (tool objects), then
    the ``name='...'`` fragment of ``str(tool)``. Failing all of those it
    returns the first 50 characters of ``str(tool)``.
    """
    if hasattr(tool, "__name__"):
        return tool.__name__
    if hasattr(tool, "name"):
        return tool.name
    text = str(tool)
    marker = "name='"
    start = text.find(marker)
    if start != -1:
        start += len(marker)
        end = text.find("'", start)
        if end != -1:
            return text[start:end]
    # Truncate: a tool's string form can be very long.
    return text[:50]


async def root():
    """Health check and agent info.

    Returns a JSON-serialisable summary with the agent's identity, business
    and domain, tool names, and the configured model. Business and domain are
    read from the nested ``business_context`` first and then from the flat
    top-level keys — the same order initialize_agent uses.
    """
    business_context = AGENT_CONFIG.get("business_context")
    if not isinstance(business_context, dict):
        business_context = {}

    return {
        "status": "online",
        "agent_name": AGENT_CONFIG.get("name") or AGENT_CONFIG.get("agent_name") or "GenericAgent",
        "agent_id": AGENT_CONFIG.get("agent_id"),
        "business": business_context.get("business_name") or AGENT_CONFIG.get("business_name"),
        "domain": business_context.get("domain") or AGENT_CONFIG.get("domain"),
        "tools_count": len(AGENT_TOOLS),
        "tools": [_tool_display_name(tool) for tool in AGENT_TOOLS],
        "model": AGENT_CONFIG.get("model"),
        "deployment": "Hugging Face Space",
    }
async def run_agent(request: ChatRequest) -> ChatResponse:
    """
    Main endpoint to interact with the agent.
    This is the primary interface for users.

    Runs AGENT_INSTANCE on ``request.message`` and wraps the final output in a
    ChatResponse.

    Raises:
        HTTPException: status 500 if running the agent or building the
            response fails; the detail carries the original error text.
    """
    import time

    # NOTE(review): a fresh in-memory session is created for every request, so
    # no conversation history survives between calls even when the client
    # reuses session_id. Confirm whether persistence is intended before keying a
    # shared store on it (the default id "default" would be shared by every
    # caller that omits it).
    session = None
    try:
        # First argument is the session *id*; the database defaults to in-memory.
        session = SQLiteSession(request.session_id or "default")
        result = await Runner.run(AGENT_INSTANCE, request.message, session=session)
        final_output = str(result.final_output) if hasattr(result, "final_output") else str(result)

        # Plain functions expose __name__; FunctionTool objects expose .name.
        tool_names = [
            getattr(tool, "__name__", None) or getattr(tool, "name", None) or str(tool)
            for tool in AGENT_TOOLS
        ]

        return ChatResponse(
            status="success",
            agent_name=AGENT_CONFIG.get("name") or AGENT_CONFIG.get("agent_name") or "GenericAgent",
            user_message=request.message,
            agent_response=final_output,
            tools_available=tool_names,
            timestamp=time.time(),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Agent execution error: {str(e)}") from e
    finally:
        # Release the per-request SQLite connection.
        if session is not None:
            session.close()
async def get_config():
    """Return a whitelisted view of AGENT_CONFIG plus the tool count.

    The provider API keys (read into separate module globals) are not part of
    this payload. NOTE(review): ``business_context`` is passed through
    verbatim — confirm it never carries sensitive fields.
    """
    exposed_keys = ("agent_id", "name", "model", "business_context")
    view = {key: AGENT_CONFIG.get(key) for key in exposed_keys}
    view["tools_count"] = len(AGENT_TOOLS)
    view["deployment_ready"] = AGENT_CONFIG.get("deployment_ready")
    return view
async def health_check():
    """Liveness probe: always reports healthy, tagged with the configured name."""
    configured_name = AGENT_CONFIG.get("name")
    payload = {"status": "healthy", "agent": configured_name}
    return payload
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)