diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..08b582d59a3a4c084dbfb83a8d0603bd38a50e43 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,7 @@ +.gradle +.venv +build +.direnv +.idea +examples +coralizer \ No newline at end of file diff --git a/.env b/.env new file mode 100644 index 0000000000000000000000000000000000000000..20270196c15e886b42a2495c4f7e7a81032ae357 --- /dev/null +++ b/.env @@ -0,0 +1,2 @@ +# API Key for Model Provider +API_KEY="" diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..547d3439d73d6376d4370c65e5e2e8d189352acf 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +images/thumnail2.png filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml new file mode 100644 index 0000000000000000000000000000000000000000..fb2cc28dcfff3a9e28990ad744e33a23b60b5139 --- /dev/null +++ b/.github/workflows/docker.yml @@ -0,0 +1,28 @@ +name: deploy-docker + +on: + push: + branches: + - 'master' +jobs: + build-and-push-image: + runs-on: ubuntu-latest + permissions: + packages: write + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Log in to the Container registry + uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 + with: + registry: https://ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and push Docker image + uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc + with: + context: . + push: true + tags: ghcr.io/coral-protocol/coral-server:latest diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml new file mode 100644 index 0000000000000000000000000000000000000000..7712d8ec4d2b9b9caa2736eff0f9c3975346177e --- /dev/null +++ b/.github/workflows/trufflehog.yml @@ -0,0 +1,18 @@ +on: + push: + branches: + - master + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@main + with: + extra_args: --results=verified,unknown \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b7c3141b2e9898478c67d3d22d42aa052d8fd2aa --- /dev/null +++ b/.gitignore @@ -0,0 +1,58 @@ +.gradle +build/ +!gradle/wrapper/gradle-wrapper.jar +!**/src/main/**/build/ +!**/src/test/**/build/ + +### IntelliJ IDEA ### +.idea/AndroidProjectSystem.xml +.idea/modules.xml +.idea/jarRepositories.xml +.idea/compiler.xml +.idea/libraries/ +.idea/misc.xml +.idea/gradle.xml +.idea/codeStyles/ +*.iws +*.iml +*.ipr +out/ +!**/src/main/**/out/ +!**/src/test/**/out/ + +### Kotlin ### +.kotlin + +### Eclipse ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache +bin/ +!**/src/main/**/bin/ +!**/src/test/**/bin/ + +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ + +### VS Code ### +.vscode/ + +### Mac OS ### +.DS_Store + +### Python ### +**/*.venv + +**/*.env + +### Example directories ### +examples/camel-search-maths/__pycache__/ +examples/session-posts/**.env.json \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..13566b81b018ad684f3a35fee301741b2734c8f4 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/kotlinc.xml b/.idea/kotlinc.xml new file mode 100644 index 0000000000000000000000000000000000000000..131e44d79845abd82448872bfe845f2d19c217b0 --- /dev/null +++ b/.idea/kotlinc.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/.idea/material_theme_project_new.xml b/.idea/material_theme_project_new.xml new file mode 100644 index 0000000000000000000000000000000000000000..31c9bb29202f3160df3f2273815b87a39c07fdf2 --- /dev/null +++ b/.idea/material_theme_project_new.xml @@ -0,0 +1,13 @@ + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000000000000000000000000000000000000..94a25f7f4cb416c083d265558da75d457237d671 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..68bab3ea3950ddacfd3a33ab0eb7bdd91e4df01f --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,3 @@ +We really appeciate and welcome all contributions! + +By contributing via creating a pull request or pushing directly to this repo, you agree to assign your copyright of the contribution to the Coral Development Team when it is accepted (merged with or without minor changes). You assert that you have full power to assign the copyright, and that any copyright owned by or shared with a third party has been clearly marked with appropriate copyright notices. If you are employed, please check with your employer about the owernership of your contribution. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..e6bcc7c202f3e7b464ba1edd08bc9ac9594631df --- /dev/null +++ b/Dockerfile @@ -0,0 +1,25 @@ +FROM gradle:8.14.2-jdk21-alpine AS build +COPY --chown=gradle:gradle . /home/gradle/src +WORKDIR /home/gradle/src + +RUN jlink \ + --verbose \ + --add-modules java.base,jdk.unsupported,java.desktop,java.instrument,java.logging,java.management,java.sql,java.xml \ + --compress 2 --strip-debug --no-header-files --no-man-pages \ + --output /opt/minimal-java + +RUN gradle build --no-daemon -x test + +FROM alpine:3 + +ENV JAVA_HOME=/opt/minimal-java +ENV PATH="$JAVA_HOME/bin:$PATH" + +ENV CONFIG_PATH="/config" + +RUN mkdir /app +# Copy the custom minimal JRE from the builder stage +COPY --from=build "$JAVA_HOME" "$JAVA_HOME" +COPY --from=build /home/gradle/src/build/libs/ /app/ + +ENTRYPOINT ["java","-jar", "/app/coral-server-1.0-SNAPSHOT.jar"] \ No newline at end of file diff --git a/README.md b/README.md index fd10451dc62fc081d46ae27f56c0623a3c0a2b61..17c8a476dc5f9ddbcda45bc4572823aac81c0417 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,143 @@ ---- -title: Corol Server -emoji: 🚀 -colorFrom: purple -colorTo: blue -sdk: docker -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Coral Server - Agent Fuzzy A2A (Agent to Agent) Communication MCP Tools + +An implementation of the Coral protocol that acts as an MCP server providing tools for agents to communicate with each other. +![999](https://github.com/user-attachments/assets/2b74074e-42c2-4abd-9827-ea3c68b75c99) + +## Project Description + +This project implements a Model Context Protocol (MCP) server that facilitates communication between AI agents through a thread-based messaging system. + + +Currently, it provides a set of tools that allow agents to: + +- Register themselves in the system +- Create and manage conversation threads +- Send messages to threads +- Mention other agents in messages +- Receive notifications when mentioned + +The server can be run in different modes (stdio, SSE) to support various integration scenarios. + +![0000](https://github.com/user-attachments/assets/a5227d18-8c57-48b9-877f-97859b176957) + +### Status / future direction +This project is in its early stages and is not yet production-ready. The current focus is on building a robust foundation for agent communication, with plans to add more features and improve performance in the future. + +Right now, this is "Local-mode" only, but we are working on a "Remote-mode" that will allow agents to communicate over the internet. + +For remote mode, we will mostly preserve the interface provided by these MCP tools, but add server configuration options to allow for communicating with remote coral servers to add their agents to the society graph. + +We don't want to re-invent the wheel, so we will reuse existing protocols and standards as much as possible. +Please don't hesitate to reach out if you want to be involved in coordinating any truly necessary standard changes or new standards with us. + +## How to Run + +### Quick example +This repo is a server that enables agents to communicate with each other, for an example of a full multi-agent system using this, check out +[the example here](/examples/camel-search-maths) or for a step-by-step guide to building agentic applications from scratch, follow this tutorial: + [https://github.com/Coral-Protocol/existing-agent-sessions-tutorial-private-temp](https://github.com/Coral-Protocol/existing-agent-sessions-tutorial-private-temp) + +### Demo Video + +[![Coral Server Demo](images/thumnail2.png)](https://youtu.be/MyokByTzY90) +*Click the image above to watch the demo video* + +The project can be run in several modes: + +### Using Gradle + +```bash +# Run with SSE server using Ktor plugin (default, port 5555) +./gradlew run + +# Run with custom arguments +./gradlew run --args="--stdio" +./gradlew run --args="--sse-server 5555" +``` + + +### Using Docker + +Install [Docker](https://docs.docker.com/desktop/) + +```bash +# Build the Docker Image +docker build -t coral-server . + +# Run the Docker Container +docker run -p 5555:5555 -v /path/to/your/coral-server/src/main/resources:/config coral-server +``` + +### Run Modes + +- `--stdio`: Runs an MCP server using standard input/output +- `--sse-server-ktor `: Runs an SSE MCP server using Ktor plugin (default if no argument is provided) +- `--sse-server `: Runs an SSE MCP server with a plain configuration + +## Available Tools + +The server provides the following tools for agent communication: + +### Agent Management +- `list_agents`: List all registered agents + +### Thread Management +- `create_thread`: Create a new thread with participants +- `add_participant`: Add a participant to a thread +- `remove_participant`: Remove a participant from a thread +- `close_thread`: Close a thread with a summary + +### Messaging +- `send_message`: Send a message to a thread +- `wait_for_mentions`: Wait for new messages mentioning an agent + +## Connections (SSE Mode) + +### Coral Server +You can connect to the server on: + +```bash +http://localhost:5555/devmode/exampleApplication/privkey/session1/sse +``` + +### MCP Inspector +You can connect to the server using the MCP Inspector command: + +```bash +npx @modelcontextprotocol/inspector sse --url http://localhost:5555/devmode/exampleApplication/privkey/session1/sse +``` +### Register an Agent +You can register an agent to the Coral Server (also can be registered on MCP inspector) on: + +```bash +http://localhost:5555/devmode/exampleApplication/privkey/session1/sse?agentId=test_agent +``` + + +## Philosophy + +Open infrastructure for the Society of AI Agents + +It's a strange concept; we believe that much of what we now consider work will be handled by a different kind of society—a Society of AI Agents. + +To bridge this gap, Coral Protocol was built as the connective tissue of this society. Coral is designed to enable agents to discover one another, communicate securely, exchange value, and scale their collaborative efforts from any framework. + +We theorize that not only will this fix many problems with the composability of multi-agent systems, but it will also unlock their full potential to be much more capable and safe, this is due to the graph-like structure that prevents any one agent from holding too much power or becoming overwhelmed with too much responsibility. + +## Contribution Guidelines + +We welcome contributions! Email us at [hello@coralprotocol.org](mailto:hello@coralprotocol.org) or join our Discord [here](https://discord.gg/rMQc2uWXhj) to connect with the developer team. Feel free to open issues or submit pull requests. + +Thanks for checking out the project, we hope you like it! + +### Development +IntelliJ IDEA is recommended for development. The project uses Gradle as the build system. + +To clone and import the project: +Go to File > New > Project from Version Control > Git. +enter `git@github.com:Coral-Protocol/coral-server.git` +Click Clone. + +### Running from IntelliJ IDEA +You can click the play button next to the main method in the `Main.kt` file to run the server directly from IntelliJ IDEA. + diff --git a/build.gradle.kts b/build.gradle.kts new file mode 100644 index 0000000000000000000000000000000000000000..619fc27e597674f75927f7f9abc3e3b1ace5031b --- /dev/null +++ b/build.gradle.kts @@ -0,0 +1,125 @@ +plugins { + kotlin("jvm") version "2.1.20" + kotlin("plugin.serialization") version "2.1.20" + application +} + +application { + mainClass.set("org.coralprotocol.coralserver.MainKt") +} + +group = "org.coralprotocol" +version = "1.0-SNAPSHOT" + +repositories { + mavenCentral() + maven { + url = uri("https://central.sonatype.com/repository/maven-snapshots/") + name = "sonatypeSnapshots" + } + + maven("https://repo.repsy.io/mvn/chrynan/public") + maven("https://github.com/CaelumF/schema-kenerator/raw/develop/maven-repo") +} + + +dependencies { + testImplementation(kotlin("test")) + implementation("io.modelcontextprotocol:kotlin-sdk:0.5.0") + implementation("org.slf4j:slf4j-simple:2.0.9") + implementation("io.github.oshai:kotlin-logging-jvm:7.0.3") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.10.1") + implementation("org.jetbrains.kotlinx:kotlinx-serialization-core:1.8.1") + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.8.1") + implementation("com.charleskorn.kaml:kaml:0.78.0") // YAML serialization + implementation("io.github.pdvrieze.xmlutil:core:0.91.0") // XML serialization + implementation("io.github.pdvrieze.xmlutil:serialization:0.91.0") + implementation("io.github.pdvrieze.xmlutil:core-jdk:0.91.0") + implementation("io.github.pdvrieze.xmlutil:serialization-jvm:0.91.0") + implementation("com.github.docker-java:docker-java:3.5.1") + + + // Hoplite for configuration + implementation("com.sksamuel.hoplite:hoplite-core:2.9.0") + implementation("com.sksamuel.hoplite:hoplite-yaml:2.9.0") + + val ktorVersion = "3.0.2" + implementation(enforcedPlatform("io.ktor:ktor-bom:$ktorVersion")) + implementation("io.ktor:ktor-server-status-pages:${ktorVersion}") + + val uriVersion="0.5.0" + implementation("com.chrynan.uri.core:uri-core:$uriVersion") + implementation("com.chrynan.uri.core:uri-ktor-client:$uriVersion") + + // Ktor testing dependencies + testImplementation("io.ktor:ktor-server-test-host") + testImplementation("io.ktor:ktor-client-mock") + val arcVersion = "0.126.0" + // Arc agents for E2E tests + testImplementation("org.eclipse.lmos:arc-agents:$arcVersion") + testImplementation("org.eclipse.lmos:arc-mcp:$arcVersion") + testImplementation("org.eclipse.lmos:arc-server:$arcVersion") + testImplementation("org.eclipse.lmos:arc-azure-client:$arcVersion") + testImplementation("org.eclipse.lmos:arc-langchain4j-client:$arcVersion") + testImplementation("io.modelcontextprotocol.sdk:mcp:0.11.0-SNAPSHOT") // Override MCP Java client for Arc 0.126.0 + testImplementation("io.mockk:mockk:1.14.2") + + // kotest + // TODO: Use kotest for some or all tests +// val kotestVersion = "5.9.1" +// testImplementation("io.kotest:kotest-runner-junit5:$kotestVersion") +// testImplementation("io.kotest:kotest-assertions-core:$kotestVersion") +// testImplementation("io.kotest:kotest-property:$kotestVersion") + + // Ktor client dependencies + implementation("io.ktor:ktor-client-logging") + implementation("io.ktor:ktor-client-content-negotiation") + implementation("io.ktor:ktor-client-cio-jvm") + implementation("io.ktor:ktor-serialization-kotlinx-json") + implementation("io.ktor:ktor-client-plugins") + + implementation("net.pwall.json:json-kotlin-schema:0.56") + + // Ktor server dependencies + implementation("io.ktor:ktor-server-core") + implementation("io.ktor:ktor-server-cio") + implementation("io.ktor:ktor-server-sse") + implementation("io.ktor:ktor-server-html-builder") + implementation("io.ktor:ktor-server-cors") + implementation("io.ktor:ktor-server-content-negotiation") + implementation("io.ktor:ktor-server-resources") + testImplementation("io.ktor:ktor-server-core") + testImplementation("io.ktor:ktor-server-cio") + testImplementation("io.ktor:ktor-server-sse") + + // TOML serialization + implementation("com.akuleshov7:ktoml-core:0.7.0") + implementation("com.akuleshov7:ktoml-file:0.7.0") + + // OpenAPI + val ktorToolsVersion = "5.2.0" + implementation("io.github.smiley4:ktor-openapi:${ktorToolsVersion}") + implementation("io.github.smiley4:ktor-redoc:${ktorToolsVersion}") + + val schemaVersion = "2.4.0.1" + implementation("io.github.smiley4:schema-kenerator-core:${schemaVersion}") + implementation("io.github.smiley4:schema-kenerator-serialization:${schemaVersion}") + implementation("io.github.smiley4:schema-kenerator-swagger:${schemaVersion}") +} + +tasks.test { + useJUnitPlatform() +} + +tasks.jar { + manifest { + attributes["Main-Class"] = "org.coralprotocol.coralserver.MainKt" + } + from(configurations.runtimeClasspath.get().map { if (it.isDirectory) it else zipTree(it) }) + duplicatesStrategy = DuplicatesStrategy.EXCLUDE + exclude("META-INF/*.RSA", "META-INF/*.SF", "META-INF/*.DSA") +} + +kotlin { + jvmToolchain(21) +} diff --git a/examples/camel-resources/camel-interface-resource.py b/examples/camel-resources/camel-interface-resource.py new file mode 100644 index 0000000000000000000000000000000000000000..21a48593835df07f064135639f7980aba580be7b --- /dev/null +++ b/examples/camel-resources/camel-interface-resource.py @@ -0,0 +1,197 @@ +import asyncio +import os +import json +from camel.toolkits.mcp_toolkit import MCPClient +from camel.toolkits import HumanToolkit, MCPToolkit +from camel.models import ModelFactory +from camel.types import ModelPlatformType, ModelType +from camel.agents import ChatAgent +import urllib.parse +import base64 +from mcp import ClientSession +from mcp.types import BlobResourceContents, ResourceContents, TextResourceContents +from typing import Union, Optional, List + +async def get_tools_description(tools): + descriptions = [] + for tool in tools: + tool_name = getattr(tool.func, '__name__', 'unknown_tool') + schema = tool.get_openai_function_schema() or {} + arg_names = list(schema.get('parameters', {}).get('properties', {}).keys()) if schema else [] + description = tool.get_function_description() or 'No description' + schema_str = json.dumps(schema, default=str).replace('{', '{{').replace('}', '}}') + descriptions.append( + f"Tool: {tool_name}, Args: {arg_names}, Description: {description}, Schema: {schema_str}" + ) + return "\n".join(descriptions) + +class SimpleBlob: + """A simple class to hold resource data, MIME type, and metadata.""" + def __init__(self, data: Union[str, bytes], mime_type: Optional[str], metadata: dict): + self.data = data + self.mime_type = mime_type + self.metadata = metadata + + @classmethod + def from_data(cls, data: Union[str, bytes], mime_type: Optional[str] = None, metadata: Optional[dict] = None): + """Create a SimpleBlob from data.""" + return cls(data=data, mime_type=mime_type, metadata=metadata or {}) + +def convert_mcp_resource_to_blob( + resource_uri: str, + contents: ResourceContents, +) -> SimpleBlob: + if isinstance(contents, TextResourceContents): + data = contents.text + elif isinstance(contents, BlobResourceContents): + data = base64.b64decode(contents.blob) + else: + raise ValueError(f"Unsupported content type for URI {resource_uri}") + return SimpleBlob.from_data( + data=data, + mime_type=contents.mimeType, + metadata={"uri": resource_uri}, + ) + +async def get_mcp_resource(session: ClientSession, uri: str) -> List[SimpleBlob]: + contents_result = await session.read_resource(uri) + if not contents_result.contents or len(contents_result.contents) == 0: + return [] + return [ + convert_mcp_resource_to_blob(uri, content) for content in contents_result.contents + ] + +async def load_mcp_resources( + session: ClientSession, + uris: Union[str, List[str], None] = None, +) -> List[SimpleBlob]: + blobs = [] + if uris is None: + resources_list = await session.list_resources() + uri_list = [r.uri for r in resources_list.resources] + elif isinstance(uris, str): + uri_list = [uris] + else: + uri_list = uris + for uri in uri_list: + try: + resource_blobs = await get_mcp_resource(session, uri) + blobs.extend(resource_blobs) + except Exception as e: + print(f"Error fetching resource {uri}: {e}") + continue + return blobs + +async def get_resources( + client: MCPClient, + uris: Union[str, List[str], None] = None +) -> List[SimpleBlob]: + """Get resources from the MCP server. + + Args: + client: MCPClient instance + uris: Optional resource URI or list of URIs to load. If None, fetches all resources. + + Returns: + A list of SimpleBlob objects + """ + if client.session is None: + raise RuntimeError("MCPClient is not connected or session is not initialized.") + try: + return await load_mcp_resources(client.session, uris) + except Exception as e: + raise RuntimeError(f"Error fetching resources: {e}") + +async def main(): + base_url_1 = "http://localhost:5555/devmode/exampleApplication/privkey/session1/sse" + params_1 = { + "waitForAgents": 1, + "agentId": "user_interface_agent", + "agentDescription": "You are user_interaction_agent, responsible for engaging with users, processing instructions, and coordinating with other agents" + } + query_string = urllib.parse.urlencode(params_1) + MCP_SERVER_URL_1 = f"{base_url_1}?{query_string}" + + coral_server = MCPClient( + command_or_url=MCP_SERVER_URL_1, + timeout=300.0 + ) + await coral_server.__aenter__() + print(f"Connected to MCP server as user_interface_agent at {MCP_SERVER_URL_1}") + + model = ModelFactory.create( + model_platform=ModelPlatformType.OPENAI, + model_type=ModelType.GPT_4O_MINI, + api_key=os.getenv("OPENAI_API_KEY"), + model_config_dict={"temperature": 0.3, "max_tokens": 16000}, + ) + + while True: + try: + resources = await get_resources(coral_server, uris=None) + if not resources: + agent_resources = "NA" + print("No resources found.") + else: + agent_resources = "\n".join(str(blob.data) for blob in resources) + print("Resources fetched:") + # for blob in resources: + # print(blob.data) + except Exception as e: + print(f"Error retrieving resources: {e}") + agent_resources = "NA" + + resource_sys_message = agent_resources + + mcp_toolkit = MCPToolkit([coral_server]) + tools = mcp_toolkit.get_tools() + HumanToolkit().get_tools() + tools_description = await get_tools_description(tools) + + sys_msg = ( + f"""You are an agent interacting with the tools from Coral Server and having your own Human Tool to ask have a conversation with Human. + Your resources, provided in `resource_sys_message`, contain thread-based conversations between agents in XML format. + Each thread includes details such as thread ID, participant agent IDs, message content, and timestamps. + Use these resources to understand past agent interactions and inform your decisions when coordinating with other agents or responding to user queries. + + Follow these steps in order: + 1. Use `list_agents` to list all connected agents and get their descriptions. + 2. Use `ask_human_via_console` to ask, "How can I assist you today?" and capture expect response. + 3. Take 2 seconds to think and understand the user's intent and decide the right agent to handle the request based on list of agents. + 4. If the user wants any information about the coral server, use the tools to get the information and pass it to the user. Do not send any message to any other agent, just give the information and go to Step 1. + 5. Once you have the right agent, use `create_thread` to create a thread with the selected agent. If no agent is available, use the `ask_human` tool to specify the agent you want to use. + 6. Use your logic to determine the task you want that agent to perform and create a message for them which instructs the agent to perform the task called "instruction". + 7. Use `send_message` to send a message in the thread, mentioning the selected agent, with content: "instructions". + 8. Use `wait_for_mentions` with a 30 seconds timeout to wait for a response from the agent you mentioned. + 9. Show the entire conversation in the thread to the user. + 10. Wait for 3 seconds and then use `ask_human` to ask the user if they need anything else and keep waiting for their response. + 11. If the user asks for something else, repeat the process from step 1. + + Use only listed tools: {tools_description} + Your resources are: {resource_sys_message}""" + ) + + camel_agent = ChatAgent( + system_message=sys_msg, + model=model, + tools=tools, + ) + print("ChatAgent initialized with updated resources!") + print("Resource System Message before agent question:") + print(resource_sys_message) + + prompt = "As the user_interaction_agent on the Coral Server, initiate your workflow by listing all connected agents and asking the user how you can assist them." + try: + response = await camel_agent.astep(prompt) + print("Agent Reply:") + print(response.msgs[0].content) + except Exception as e: + print(f"Error processing agent response: {e}") + + await asyncio.sleep(3) + + continue + + await coral_server.__aexit__(None, None, None) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/camel-search-maths/README.md b/examples/camel-search-maths/README.md new file mode 100644 index 0000000000000000000000000000000000000000..94f9b3dcf471ca14c66f9e4f8bd4ec83fcd1d437 --- /dev/null +++ b/examples/camel-search-maths/README.md @@ -0,0 +1,91 @@ +In this example we have 3 agents implemented with CAMEL working together to answer a user query. + +To run it, you need to have the dependencies installed: + +# Running the example + +# 1. Install the dependencies +```bash +pip install -r requirements.txt +``` + +You will need to install CAMEL with all optional dependencies until they fix the minimal requirements version. + +```bash +pip install "camel-ai[all]" +``` + +## 2. Start the server +Cd to this project's root directory and run the server. +```bash +./gradlew run +``` + +Note that gradle will show "83"% forever, but it is actually running. You can check the logs in the terminal to see if it is up and running. + +## 3. Run the agents +Ensure you have an OPENAI_API_KEY set in your environment variables (or change to a different model in the agents) + +Before running the agents, you can configure the model settings in `config.py`: +```python +# Model Configuration +PLATFORM_TYPE = "OPENAI" # Change the model provider +MODEL_TYPE = "GPT_4O" # Change the model type + +# Model Settings +MODEL_CONFIG = { + "temperature": 0.3, # Adjust model parameters + "max_tokens": 4096, +} +``` +For available model providers and types, refer to the [CAMEL model types documentation](https://github.com/camel-ai/camel/blob/master/camel/types/enums.py). + +In a separate terminal, run the agents. They all need to be running for this example to work. + +```bash +python mcp_example_camel_math.py +``` + +```bash +python mcp_example_camel_search.py +``` + +```bash +python mcp_example_camel_interface.py +``` + + +## 4. Interact with the agents + +You will eventually see the interface agent asking for your query via STDIN. Write your query and hit enter. +Try asking for example: + +``` +What is the square root of the area of Konstanz? +``` + +The society will then work together to address your query, and the interface agent will share their findings with you. + + +## Troubleshooting +The agents are limited to iterate only 20 times to prevent accidental API expenses, so they might need restarting if they've been alive too long. + +Also right now the agents will not be unregistered, so make sure to restart the server if you want to run them again. + +This is very early, so we welcome any questions no matter how silly they might seem so we can improve the documentation and Dev Experience! + +Come by our Discord for any questions or suggestions: https://discord.gg/cDzGHnzkwD + +--- + + +# Build on the example +Now that you've got your society running, you can build on it. + +Adding another agent is as simple as copying and pasting one of these agent files and running it too. +Don't forget to prompt it to assume a different name. + + +# Future potential +At the time of writing, this is a proof of concept. Server and agent lifecycle questions remain. +The scope of this project includes answering these questions with remote mode and sessions. \ No newline at end of file diff --git a/examples/camel-search-maths/config.py b/examples/camel-search-maths/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8d79e7b3a39f247b5ede5276558618bac166fcad --- /dev/null +++ b/examples/camel-search-maths/config.py @@ -0,0 +1,17 @@ +"""Configuration file for model settings""" + +# Model Configuration +# for more information on the models, see https://github.com/camel-ai/camel/blob/master/camel/types/enums.py + +PLATFORM_TYPE = "OPENAI" +MODEL_TYPE = "GPT_4O" + +# Model Settings +MODEL_CONFIG = { + "temperature": 0.3, + "max_tokens": 4096, +} + +# Agent Settings +MESSAGE_WINDOW_SIZE = 4096 * 50 +TOKEN_LIMIT = 20000 \ No newline at end of file diff --git a/examples/camel-search-maths/interface/coral-agent.toml b/examples/camel-search-maths/interface/coral-agent.toml new file mode 100644 index 0000000000000000000000000000000000000000..98d1e15dd2a4d45ab7507c6c35139d7bb02f4986 --- /dev/null +++ b/examples/camel-search-maths/interface/coral-agent.toml @@ -0,0 +1,10 @@ +[agent] +name = "interface" +version = "0.0.1" + +[options.OPENAI_API_KEY] +type = "string" +description = "OpenAI API Key" + +[runtimes.executable] +command = ["bash", "examples/camel-search-maths/venv.sh", "examples/camel-search-maths/mcp_example_camel_interface.py"] \ No newline at end of file diff --git a/examples/camel-search-maths/interface/marketplace-agent.toml b/examples/camel-search-maths/interface/marketplace-agent.toml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/camel-search-maths/math/coral-agent.toml b/examples/camel-search-maths/math/coral-agent.toml new file mode 100644 index 0000000000000000000000000000000000000000..27c0fb59bd2f347ed110a15dd5a63791b6f641ff --- /dev/null +++ b/examples/camel-search-maths/math/coral-agent.toml @@ -0,0 +1,12 @@ +[agent] +name = "math" +version = "0.0.1" + +[options.OPENAI_API_KEY] +type = "string" +description = "OpenAI API Key" + +#OPENAI_API_KEY = { type = "string", description = "OpenAI API Key" } + +[runtimes.executable] +command = ["bash", "examples/camel-search-maths/venv.sh", "examples/camel-search-maths/mcp_example_camel_math.py"] \ No newline at end of file diff --git a/examples/camel-search-maths/mcp_example_camel_interface.py b/examples/camel-search-maths/mcp_example_camel_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..024cbde0c1413df97fb43cdb9d5d29520c96606e --- /dev/null +++ b/examples/camel-search-maths/mcp_example_camel_interface.py @@ -0,0 +1,74 @@ +import asyncio # Manages asynchronous operations +import os # Provide interaction with the operating system. +from time import sleep + +from camel.agents import ChatAgent # creates Agents +from camel.models import ModelFactory # encapsulates LLM +from camel.toolkits import HumanToolkit, MCPToolkit # import tools +from camel.toolkits.mcp_toolkit import MCPClient +from camel.utils.mcp_client import ServerConfig +from camel.types import ModelPlatformType, ModelType +from dotenv import load_dotenv + +from config import PLATFORM_TYPE, MODEL_TYPE, MODEL_CONFIG, MESSAGE_WINDOW_SIZE, TOKEN_LIMIT + +# load_dotenv() + +from prompts import get_tools_description, get_user_message + +async def main(): + # Simply add the Coral server address as a tool + coral_url = os.getenv("CORAL_CONNECTION_URL", default = "http://localhost:5555/devmode/exampleApplication/privkey/session1/sse?waitForAgents=3&agentId=user_interaction_agent") + server = MCPClient(ServerConfig(url=coral_url, timeout=3000000.0, sse_read_timeout=3000000.0, terminate_on_close=True, prefer_sse=True), timeout=3000000.0) + + mcp_toolkit = MCPToolkit([server]) + + async with mcp_toolkit as connected_mcp_toolkit: + print("Connected to coral server.") + camel_agent = await create_interface_agent(connected_mcp_toolkit) + + # Step the agent continuously + for i in range(20): #This should be infinite, but for testing we limit it to 20 to avoid accidental API fees + resp = await camel_agent.astep(get_user_message()) + msgzero = resp.msgs[0] + msgzerojson = msgzero.to_dict() + print(msgzerojson) + sleep(10) + +async def create_interface_agent(connected_mcp_toolkit): + tools = connected_mcp_toolkit.get_tools() + sys_msg = ( + f""" + You are a helpful assistant responsible for interacting with the user and working with other agents to meet the user's requests. You can interact with other agents using the chat tools. + User interaction is your speciality. You identify as "{os.getenv("CORAL_AGENT_ID", default = "N/A")}". + + As a user interaction agent, only you can interact with the user. Use the user_input tool to get new tasks from the user. + + Make sure that all information comes from reliable sources and that all calculations are done using the appropriate tools by the appropriate agents. Make sure your responses are much more reliable than guesses! You should make sure no agents are guessing too, by suggesting the relevant agents to do each part of a task to the agents you are working with. Do a refresh of the available agents before asking the user for input. + + Make sure to put the name of the agent(s) you are talking to in the mentions field of the send message tool. + + {os.getenv("CORAL_PROMPT_SYSTEM", default = "")} + + Here are the guidelines for using the communication tools: + {get_tools_description()} + """ + ) + model = ModelFactory.create( + model_platform=ModelPlatformType[PLATFORM_TYPE], + model_type=ModelType[MODEL_TYPE], + api_key=os.getenv("API_KEY"), + model_config_dict=MODEL_CONFIG, + ) + camel_agent = ChatAgent( # create agent with our mcp tools + system_message=sys_msg, + model=model, + tools=tools, + message_window_size=MESSAGE_WINDOW_SIZE, + token_limit=TOKEN_LIMIT + ) + return camel_agent + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/camel-search-maths/mcp_example_camel_math.py b/examples/camel-search-maths/mcp_example_camel_math.py new file mode 100644 index 0000000000000000000000000000000000000000..36255c8ebc53325a2ca8248e7e807a6f46c2e629 --- /dev/null +++ b/examples/camel-search-maths/mcp_example_camel_math.py @@ -0,0 +1,68 @@ +import asyncio +import os +from time import sleep + +from camel.agents import ChatAgent +from camel.models import ModelFactory +from camel.toolkits import MCPToolkit, MathToolkit +from camel.utils.mcp_client import ServerConfig +from camel.toolkits.mcp_toolkit import MCPClient +from camel.types import ModelPlatformType, ModelType +from prompts import get_tools_description, get_user_message +from dotenv import load_dotenv +from config import PLATFORM_TYPE, MODEL_TYPE, MODEL_CONFIG, MESSAGE_WINDOW_SIZE, TOKEN_LIMIT + +# load_dotenv() + +async def main(): + # Simply add the Coral server address as a tool + print("Starting MCP client...") + coral_url = os.getenv("CORAL_CONNECTION_URL", default = "http://localhost:5555/devmode/exampleApplication/privkey/session1/sse?agentId=math_agent") + server = MCPClient(ServerConfig(url=coral_url, timeout=3000000.0, sse_read_timeout=3000000.0, terminate_on_close=True, prefer_sse=True), timeout=3000000.0) + mcp_toolkit = MCPToolkit([server]) + + async with mcp_toolkit as connected_mcp_toolkit: + tools = connected_mcp_toolkit.get_tools() + MathToolkit().get_tools() + camel_agent = await create_math_agent(tools) + + # Step the agent continuously + for i in range(20): #This should be infinite, but for testing we limit it to 20 to avoid accidental API fees + resp = await camel_agent.astep(get_user_message()) + msgzero = resp.msgs[0] + msgzerojson = msgzero.to_dict() + print(msgzerojson) + sleep(10) + + +async def create_math_agent(tools): + sys_msg = ( + f""" + You are a helpful assistant responsible for doing maths + operations. You can interact with other agents using the chat tools. + Mathematics are your speciality. You identify as "math_agent". + + If you have no tasks yet, call the wait for mentions tool. Don't ask agents for tasks, wait for them to ask you. + {os.getenv("CORAL_PROMPT_SYSTEM", default = "")} + + Here are the guidelines for using the communication tools: + {get_tools_description()} + """ + ) + model = ModelFactory.create( + model_platform=ModelPlatformType[PLATFORM_TYPE], + model_type=ModelType[MODEL_TYPE], + api_key=os.getenv("API_KEY"), + model_config_dict=MODEL_CONFIG, + ) + camel_agent = ChatAgent( + system_message=sys_msg, + model=model, + tools=tools, + message_window_size=MESSAGE_WINDOW_SIZE, + token_limit=TOKEN_LIMIT + ) + return camel_agent + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/camel-search-maths/mcp_example_camel_search.py b/examples/camel-search-maths/mcp_example_camel_search.py new file mode 100644 index 0000000000000000000000000000000000000000..c74fd2d68f482a72c39e127f8de5b376735cc6de --- /dev/null +++ b/examples/camel-search-maths/mcp_example_camel_search.py @@ -0,0 +1,76 @@ +import asyncio +import os +from time import sleep + +from camel.agents import ChatAgent +from camel.models import ModelFactory +from camel.toolkits import FunctionTool, MCPToolkit +from camel.toolkits.mcp_toolkit import MCPClient +from camel.utils.mcp_client import ServerConfig +from camel.toolkits.search_toolkit import SearchToolkit +from camel.types import ModelPlatformType, ModelType + +from prompts import get_tools_description, get_user_message +from tools import JinaBrowsingToolkit +from dotenv import load_dotenv +from config import PLATFORM_TYPE, MODEL_TYPE, MODEL_CONFIG, MESSAGE_WINDOW_SIZE, TOKEN_LIMIT + +# load_dotenv() + +async def main(): + # Simply add the Coral server address as a tool + coral_url = os.getenv("CORAL_CONNECTION_URL", default = "http://localhost:5555/devmode/exampleApplication/privkey/session1/sse?waitForAgents=3&agentId=search_agent") + server = MCPClient(ServerConfig(url=coral_url, timeout=3000000.0, sse_read_timeout=3000000.0, terminate_on_close=True, prefer_sse=True), timeout=3000000.0) + mcp_toolkit = MCPToolkit([server]) + + async with mcp_toolkit as connected_mcp_toolkit: + camel_agent = await create_search_agent(connected_mcp_toolkit) + + # Step the agent continuously + for i in range(20): #This should be infinite, but for testing we limit it to 20 to avoid accidental API fees + resp = await camel_agent.astep(get_user_message()) + msgzero = resp.msgs[0] + msgzerojson = msgzero.to_dict() + print(msgzerojson) + sleep(10) + + +async def create_search_agent(connected_mcp_toolkit): + search_toolkit = SearchToolkit() + browse_toolkit = JinaBrowsingToolkit() + search_tools = [ + FunctionTool(search_toolkit.search_google), + FunctionTool(browse_toolkit.get_url_content), + FunctionTool(browse_toolkit.get_url_content_with_context), + ] + tools = connected_mcp_toolkit.get_tools() + search_tools + sys_msg = ( + f""" + You are a helpful assistant responsible for doing search operations. You can interact with other agents using the chat tools. + Search is your speciality. You identify as "search_agent". + + If you have no tasks yet, call the wait for mentions tool. Don't ask agents for tasks, wait for them to ask you. + {os.getenv("CORAL_PROMPT_SYSTEM", default = "")} + + Here are the guidelines for using the communication tools: + {get_tools_description()} + """ + ) + model = ModelFactory.create( + model_platform=ModelPlatformType[PLATFORM_TYPE], + model_type=ModelType[MODEL_TYPE], + api_key=os.getenv("API_KEY"), + model_config_dict=MODEL_CONFIG, + ) + camel_agent = ChatAgent( + system_message=sys_msg, + model=model, + tools=tools, + message_window_size=MESSAGE_WINDOW_SIZE, + token_limit=TOKEN_LIMIT + ) + return camel_agent + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/camel-search-maths/prompts.py b/examples/camel-search-maths/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..8c358342c9fba08451206b18f3a2a924867c3af2 --- /dev/null +++ b/examples/camel-search-maths/prompts.py @@ -0,0 +1,23 @@ +def get_tools_description(): + return """ +You have access to communication tools to interact with other agents. + +Before using the tools, you need to register yourself using the register tool. Name yourself with a name that describes your speciality well. Do not be too generic. For example, if you are a search agent, you can name yourself "search_agent". + +If there are no other agents, remember to re-list the agents periodically using the list tool. + +You should know that the user can't see any messages you send, you are expected to be autonomous and respond to the user only when you have finished working with other agents, using tools specifically for that. + +You can emit as many messages as you like before using that tool when you are finished or absolutely need user input. You are on a loop and will see a "user" message every 4 seconds, but it's not really from the user. + +When sending messages, you MUST put the name of the agent(s) you are talking to in the mentions field of the send message tool. If you don't mention anybody, nobody will receive it! + +Run the wait for mention tool when you are ready to receive a message from another agent. This is the preferred way to wait for messages from other agents. + +You'll only see messages from other agents since you last called the wait for mention tool. Remember to call this periodically. Also call this when you're waiting with nothing to do. + +Don't try to guess any numbers or facts, only use reliable sources. If you are unsure, ask other agents for help. + """ + +def get_user_message(): + return "[automated] continue collaborating with other agents. make sure to mention agents you intend to communicate with" diff --git a/examples/camel-search-maths/requirements.txt b/examples/camel-search-maths/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2bcc7d5351e1d3406ff7ac1f7012f68bcaae1cfc --- /dev/null +++ b/examples/camel-search-maths/requirements.txt @@ -0,0 +1,3 @@ +requests==2.32.3 +camel-ai==0.2.46 +asyncio==3.4.3 diff --git a/examples/camel-search-maths/search/coral-agent.toml b/examples/camel-search-maths/search/coral-agent.toml new file mode 100644 index 0000000000000000000000000000000000000000..b15fa7678c52c9b9379e5d3e9b284174aafd2270 --- /dev/null +++ b/examples/camel-search-maths/search/coral-agent.toml @@ -0,0 +1,11 @@ +[agent] +name = "search" +version = "0.0.1" + +[options] +OPENAI_API_KEY = { type = "string", description = "OpenAI API Key" } +GOOGLE_API_KEY = { type = "string", description = "Google API Key" } +SEARCH_ENGINE_ID = { type = "string", description = "Google Search Engine ID" } + +[runtimes.executable] +command = ["bash", "examples/camel-search-maths/venv.sh", "examples/camel-search-maths/mcp_example_camel_math.py"] \ No newline at end of file diff --git a/examples/camel-search-maths/tools.py b/examples/camel-search-maths/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..def1e6a9e81e1dc657acbf47c9712ebd96e17a27 --- /dev/null +++ b/examples/camel-search-maths/tools.py @@ -0,0 +1,85 @@ +import os + +import requests +from camel.toolkits import BaseToolkit + + +class JinaBrowsingToolkit(BaseToolkit): + def get_url_content(self, url: str) -> str: + r"""Fetch the content of a URL using the r.jina.ai service. + + Args: + url (str): The URL to fetch content from. + + Returns: + str: The markdown content of the URL. + """ + + # Replace http with https and add https if not present + if not url.startswith("https://"): + url = "https://" + url.lstrip("https://").lstrip("http://") + + jina_url = f"https://r.jina.ai/{url}" + headers = {} + if os.environ.get('JINA_PROXY_URL'): + headers['X-Proxy-Url'] = os.environ.get('JINA_PROXY_URL') + + auth_token = os.environ.get('JINA_AUTH_TOKEN') + if auth_token: + headers['Authorization'] = f'Bearer {auth_token}' + try: + response = requests.get(jina_url, headers=headers) + response.raise_for_status() + return response.text + except requests.RequestException as e: + return f"Error fetching URL content: {e!s}" + + def get_url_content_with_context( + self, + url: str, + search_string: str, + context_chars: int = 700, + max_instances: int = 3, + ) -> str: + r"""Fetch the content of a URL and return context around all instances of a specific string. + + Args: + url (str): The URL to fetch content from. + search_string (str): The string to search for in the content. + context_chars (int): Number of characters to return before and after each found string. + max_instances (int): Maximum number of instances to return. + + Returns: + str: The context around all found instances of the string, or an error message if not found. + + If there are no results, try again with a more likely search string. Start with a more likely string and only use a less likely string if the first one has too many results. + """ + content = self.get_url_content(url) + if content.startswith("Error fetching URL content"): + return content + + instances = [] + start = 0 + while True: + index = content.lower().find(search_string.lower(), start) + if index == -1 or len(instances) >= max_instances: + break + + context_start = max(0, index - context_chars) + context_end = min( + len(content), index + len(search_string) + context_chars + ) + instance_context = content[context_start:context_end] + instances.append( + f"Instance {len(instances) + 1}:\n{instance_context}\n" + ) + + start = index + len(search_string) + + if instances: + return ( + f"Found {len(instances)} instance(s) of '{search_string}':\n\n" + + '\n'.join(instances) + ) + else: + return f"Search string '{search_string}' not found in the content." diff --git a/examples/camel-search-maths/venv.sh b/examples/camel-search-maths/venv.sh new file mode 100644 index 0000000000000000000000000000000000000000..d4f4032bc137aa038c030f5d196134aeb1355a19 --- /dev/null +++ b/examples/camel-search-maths/venv.sh @@ -0,0 +1,2 @@ +source examples/camel-search-maths/.venv/bin/activate +python -u "$1" \ No newline at end of file diff --git a/examples/session-posts/Local docker session.http b/examples/session-posts/Local docker session.http new file mode 100644 index 0000000000000000000000000000000000000000..766b95eb166882add3f7469f224ad76ede33a092 --- /dev/null +++ b/examples/session-posts/Local docker session.http @@ -0,0 +1,41 @@ +POST http://localhost:5555/sessions +Content-Type: application/json + +{ + "applicationId": "app", + "privacyKey": "priv", + "agentGraph": { + "agents": { + "my-deepresearch": { + "type": "local", + "agentType": "deepresearch", + "options": { + "OPENAI_API_KEY": "{{OPENAI_API_KEY}}", + "LINKUP_API_KEY": "{{LINKUP_API_KEY}}" + } + }, + "my-repounderstanding": { + "type": "local", + "agentType": "repounderstanding", + "options": { + "OPENAI_API_KEY": "{{OPENAI_API_KEY}}", + "GITHUB_ACCESS_TOKEN": "{{GITHUB_ACCESS_TOKEN}}" + } + }, + "my-interface": { + "type": "local", + "agentType": "interface", + "options": { + "OPENAI_API_KEY": "{{OPENAI_API_KEY}}", + "HUMAN_RESPONSE": "Please give me a comprehensive instruction of the master branch of Coral-Protocol/coral-server." + } + } + }, + "links": [ + ["my-repounderstanding", "my-deepresearch", "my-interface"] + ] + } +} + + +### \ No newline at end of file diff --git a/examples/session-posts/application.yaml b/examples/session-posts/application.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f432ceca75bae72ebb1c3aef3c5fac12bdb3a103 --- /dev/null +++ b/examples/session-posts/application.yaml @@ -0,0 +1,84 @@ +# Default application configuration +# TODO: Applications are a work in progress. This is safe to ignore for now. +applications: + - id: "app" + name: "Default Application" + description: "Default application for testing" + privacyKeys: + - "default-key" + - "public" + - "priv" + +# NOTE: this will almost certainly *not* work on your machine without some tweaking of the `command`'s for each agent + +# Registry of agents we can orchestrate +registry: + # test: + # options: + # - name: "NAME" + # type: "string" + # description: "Test agent name" + # runtime: + # type: "executable" + # command: ["bash", "examples/camel-search-maths/venv.sh", "examples/camel-search-maths/test.py"] + # environment: + # - option: "NAME" + repounderstanding: + # Exposed configuration for consumers of this agent + options: + - name: "OPENAI_API_KEY" + type: "string" + description: "OpenAI API Key" + - name: "GITHUB_ACCESS_TOKEN" + type: "string" + description: "GitHub Access Token" + + # How this agent is actually orchestrated locally + runtime: + type: "docker" + environment: + - name: "API_KEY" + from: "OPENAI_API_KEY" + - name: "GITHUB_ACCESS_TOKEN" + from: "GITHUB_ACCESS_TOKEN" + image: "sd2879/coral-repounderstanding:latest" + + deepresearch: + options: + - name: "OPENAI_API_KEY" + type: "string" + description: "OpenAI API Key" + - name: "LINKUP_API_KEY" + type: "string" + description: "LinkUp API Key. Get from https://linkup.so/" + + runtime: + type: "docker" + environment: + - name: "API_KEY" + from: "OPENAI_API_KEY" + image: "sd2879/coral-opendeepresearch:latest" + + interface: + options: + - name: "OPENAI_API_KEY" + type: "string" + description: "OpenAI API Key" + - name: "HUMAN_RESPONSE" + type: "string" + description: "Human response to be used in the interface agent" + + runtime: + type: "docker" + image: "sd2879/coral-interface-agent:latest" + environment: + - name: "API_KEY" + from: "OPENAI_API_KEY" + - name: "HUMAN_RESPONSE" + from: "HUMAN_RESPONSE" + +# Uncomment to configure an external application source +# applicationSource: +# type: "http" +# url: "https://example.com/applications" +# refreshIntervalSeconds: 3600 \ No newline at end of file diff --git a/gradle.properties b/gradle.properties new file mode 100644 index 0000000000000000000000000000000000000000..a157a028609397f8059ba55e723d07cfd0fe1d7e --- /dev/null +++ b/gradle.properties @@ -0,0 +1,2 @@ +kotlin.code.style=official +org.gradle.console=plain \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000000000000000000000000000000..249e5832f090a2944b7473328c07c9755baa3196 Binary files /dev/null and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000000000000000000000000000000000..2a4c20598d6ff383539c8f895f2f0609f163bb4a --- /dev/null +++ b/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,8 @@ +#Fri Mar 21 11:39:48 GMT 2025 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists +org.gradle.java.installations.fromEnv=JAVA_HOME +org.gradle.java.installations.auto-download=true diff --git a/gradlew b/gradlew new file mode 100644 index 0000000000000000000000000000000000000000..1b6c787337ffb79f0e3cf8b1e9f00f680a959de1 --- /dev/null +++ b/gradlew @@ -0,0 +1,234 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit + +APP_NAME="Gradle" +APP_BASE_NAME=${0##*/} + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat new file mode 100644 index 0000000000000000000000000000000000000000..107acd32c4e687021ef32db511e8a206129b88ec --- /dev/null +++ b/gradlew.bat @@ -0,0 +1,89 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/images/thumnail2.png b/images/thumnail2.png new file mode 100644 index 0000000000000000000000000000000000000000..90ce7e5213c8b19a822ee31d489d00f7881e6eb3 --- /dev/null +++ b/images/thumnail2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0e15fb7815341b88995b4484edc0f4a35f4312b233437d14431b4eb54e7cb2c +size 1654340 diff --git a/settings.gradle.kts b/settings.gradle.kts new file mode 100644 index 0000000000000000000000000000000000000000..ed340cc073d76d9a9d13d5113c038383e46d6ed8 --- /dev/null +++ b/settings.gradle.kts @@ -0,0 +1,5 @@ +plugins { + id("org.gradle.toolchains.foojay-resolver-convention") version "0.8.0" +} +rootProject.name = "coral-server" + diff --git a/src/main/kotlin/org/coralprotocol/coralserver/EventBus.kt b/src/main/kotlin/org/coralprotocol/coralserver/EventBus.kt new file mode 100644 index 0000000000000000000000000000000000000000..963270151c24167892adafe56a3d6690331a270d --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/EventBus.kt @@ -0,0 +1,13 @@ +package org.coralprotocol.coralserver + +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.asSharedFlow + +class EventBus(val replay: Int = 0) { + private val _events = MutableSharedFlow(extraBufferCapacity = 1024, replay = replay) // private mutable shared flow + val events = _events.asSharedFlow() // publicly exposed as read-only shared flow + + fun emit(event: E) { + _events.tryEmit(event) // suspends until all subscribers receive it + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/IterHelpers.kt b/src/main/kotlin/org/coralprotocol/coralserver/IterHelpers.kt new file mode 100644 index 0000000000000000000000000000000000000000..0c97f733bef4c0c33b9d71996ec6f97bfd589aa0 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/IterHelpers.kt @@ -0,0 +1,10 @@ +package org.coralprotocol.coralserver + +fun List>.toMapOnDuplicate(onDuplicates: (duplicates: List) -> Unit): Map { + val groups: Map>> = groupBy { it.first } + val duplicates = groups.filter { it.value.size > 1 }.map { it.key } + if (duplicates.isNotEmpty()) { + onDuplicates(duplicates) + } + return groups.mapValues { it.value.first().second } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/Main.kt b/src/main/kotlin/org/coralprotocol/coralserver/Main.kt new file mode 100644 index 0000000000000000000000000000000000000000..c392430053bb793af7408c499611abc1187f85db --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/Main.kt @@ -0,0 +1,57 @@ +package org.coralprotocol.coralserver + +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.runBlocking +import org.coralprotocol.coralserver.config.ConfigCollection +import org.coralprotocol.coralserver.agent.runtime.Orchestrator +import org.coralprotocol.coralserver.server.CoralServer +import org.coralprotocol.coralserver.session.SessionManager + +private val logger = KotlinLogging.logger {} + +/** + * Start sse-server mcp on port 5555. + * + * @param args + * - "--stdio": Runs an MCP server using standard input/output. + * - "--sse-server ": Runs an SSE MCP server with a plain configuration. + * - "--dev": Runs the server in development mode. + */ +fun main(args: Array) { +// System.setProperty("org.slf4j.simpleLogger.defaultLogLevel", "TRACE"); +// System.setProperty("io.ktor.development", "true") + + val command = args.firstOrNull() ?: "--sse-server" + val port = args.getOrNull(1)?.toUShortOrNull() ?: 5555u + val devMode = args.contains("--dev") + + when (command) { +// "--stdio" -> runMcpServerUsingStdio() + "--sse-server" -> { + val appConfig = ConfigCollection() + + val orchestrator = Orchestrator(appConfig) + val server = CoralServer( + port = port, + devmode = devMode, + appConfig = appConfig, + sessionManager = SessionManager(orchestrator, port = port) + ) + + // Add shutdown hook to stop the server gracefully + Runtime.getRuntime().addShutdownHook(Thread { + logger.info { "Shutting down server..." } + appConfig.stopWatch() + server.stop() + runBlocking { + orchestrator.destroy() + } + }) + + server.start(wait = true) + } + else -> { + logger.error { "Unknown command: $command" } + } + } +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgent.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgent.kt new file mode 100644 index 0000000000000000000000000000000000000000..1da84a374fcedb034953a3781374e43f75f2fc59 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgent.kt @@ -0,0 +1,43 @@ +package org.coralprotocol.coralserver.agent.graph + +import io.github.smiley4.schemakenerator.core.annotations.Description +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.coralprotocol.coralserver.agent.registry.AgentOptionValue +import org.coralprotocol.coralserver.session.CustomTool + +@Serializable +@Description("The representation of an agent on the agent graph. This refers to a registry agent by name") +data class GraphAgent( + @Description("The name of the agent in the registry") + val name: String, + + @Description("The options that are passed to the agent") + val options: Map, + + @Description("The system prompt/developer text/preamble passed to the agent") + val systemPrompt: String?, + + @Description("") + val extraTools: Set, + + @Description("") + val blocking: Boolean, + + @Description("The provider for this agent") + val provider: GraphAgentProvider +) + +@Serializable +@Description("A graph of agents, tools and links between them. The agent links define agent groups") +data class AgentGraph( + @Description("A map of agent names to graph agents") + val agents: Map, + + @Description("") + val tools: Map, + + @Description("") + val links: Set>, +) + diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentProvider.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentProvider.kt new file mode 100644 index 0000000000000000000000000000000000000000..f309444a7f9fee185b8a0f43320ea27aef449008 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentProvider.kt @@ -0,0 +1,36 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.agent.graph + +import io.github.smiley4.schemakenerator.core.annotations.Description +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator +import org.coralprotocol.coralserver.agent.runtime.RuntimeId + +@Serializable +@JsonClassDiscriminator("type") +@Description("A local or remote provider for an agent") +sealed class GraphAgentProvider { + @Serializable + @SerialName("local") + @Description("The agent will be provided by this server") + data class Local( + val runtime: RuntimeId, + ) : GraphAgentProvider() + + @Serializable + @SerialName("remote") + @Description("Agent will be provided by another Coral server") + data class Remote( + @Description("The runtime that should be used for this remote agent. Servers can export only specific runtimes so the runtime choice may narrow servers that can adequately provide the agent") + val runtime: RuntimeId, + + @Description("A description of which servers should be queried for this remote agent request") + val serverSource: GraphAgentServerSource, + + @Description("Customisation for the scoring of servers") + val serverScoring: GraphAgentServerScoring? = GraphAgentServerScoring.Default() + ) : GraphAgentProvider() +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentRequest.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentRequest.kt new file mode 100644 index 0000000000000000000000000000000000000000..870bb4b601d6796da327eebe91638136ba52b9e5 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentRequest.kt @@ -0,0 +1,27 @@ +package org.coralprotocol.coralserver.agent.graph + +import io.github.smiley4.schemakenerator.core.annotations.Description +import kotlinx.serialization.Serializable +import org.coralprotocol.coralserver.agent.registry.AgentOptionValue + +@Serializable +@Description("A request for an agent") +data class GraphAgentRequest( + @Description("The name of the agent to run, this must match the name of the agent in the registry") + val agentName: String, + + @Description("The arguments to pass to the agent") + val options: Map, + + @Description("The system prompt/developer text/preamble passed to the agent") + val systemPrompt: String?, + + @Description("") + val blocking: Boolean?, + + @Description("") + val tools: Set, + + @Description("The server that should provide this agent and the runtime to use") + val provider: GraphAgentProvider +) diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentServer.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentServer.kt new file mode 100644 index 0000000000000000000000000000000000000000..46a414f5860b756a1d9bbf3e20eeaa7a4cfc928d --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentServer.kt @@ -0,0 +1,54 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.agent.graph + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator + +@Serializable +data class GraphAgentServer( + val address: String, + val attributes: List +) + +@Serializable +enum class GraphAgentServerAttributeType { + // possibly represented as a timezone + @SerialName("geographic_location") + GEOGRAPHIC_LOCATION, + + // wallet ID? + @SerialName("attested_by") + ATTESTED_BY, + + // todo: fill this out +} + +@Serializable +@JsonClassDiscriminator("format") +sealed class GraphAgentServerAttribute() { + abstract val type: GraphAgentServerAttributeType + + @Serializable + @SerialName("string") + data class String( + override val type: GraphAgentServerAttributeType, + val value: kotlin.String + ) : GraphAgentServerAttribute() + + @Serializable + @SerialName("number") + data class Number( + override val type: GraphAgentServerAttributeType, + val value: Double + ) : GraphAgentServerAttribute() + + @Serializable + @SerialName("boolean") + data class Boolean( + override val type: GraphAgentServerAttributeType, + val value: kotlin.Boolean + ) : GraphAgentServerAttribute() +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentServerScoring.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentServerScoring.kt new file mode 100644 index 0000000000000000000000000000000000000000..0e6562969022a092fb275645c61a0d6087aa308b --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentServerScoring.kt @@ -0,0 +1,156 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.agent.graph + +import io.github.smiley4.schemakenerator.core.annotations.Description +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator + +@Serializable +sealed class GraphAgentServerScorerEffect { + @Serializable + @SerialName("flat") + @Description("A flat negative or positive weight") + data class Flat(val weight: Double) : GraphAgentServerScorerEffect() + + @Serializable + @SerialName("multiplier") + @Description("A multiplier weight, this effect will only multiply against attributes with a number type") + data class Multiplier(val weight: Double) : GraphAgentServerScorerEffect() + + fun apply(value: Double = 0.0): Double { + return when (this) { + is Flat -> { + weight + } + is Multiplier -> { + value * weight + } + } + } +} + +@Serializable +@JsonClassDiscriminator("op") +sealed class GraphAgentServerCustomScorer { + @Serializable + @SerialName("is_true") + @Description("The effect will be applied for every attribute of the specified type with a boolean true value") + data class IsTrue( + val type: GraphAgentServerAttributeType, + val effect: GraphAgentServerScorerEffect.Flat + ) : GraphAgentServerCustomScorer() + + @Serializable + @SerialName("is_false") + @Description("The effect will be applied for every attribute of the specified type with a boolean false value") + data class IsFalse( + val type: GraphAgentServerAttributeType, + val effect: GraphAgentServerScorerEffect.Flat + ) : GraphAgentServerCustomScorer() + + @Serializable + @SerialName("is_present") + @Description("The effect will be applied for every attribute of the specified type") + data class IsPresent( + val type: GraphAgentServerAttributeType, + val effect: GraphAgentServerScorerEffect + ) : GraphAgentServerCustomScorer() + + @Serializable + @SerialName("is_not_present") + @Description("The effect will be applied if the no attribute of the specified type is present") + data class IsNotPresent( + val type: GraphAgentServerAttributeType, + val effect: GraphAgentServerScorerEffect.Flat + ) : GraphAgentServerCustomScorer() + + @Serializable + @SerialName("string_equal") + @Description("The effect will be applied for every attribute of the specified type with a matching string value") + data class StringEqual( + val type: GraphAgentServerAttributeType, + val string: String, + val effect: GraphAgentServerScorerEffect.Flat + ) : GraphAgentServerCustomScorer() + + @Serializable + @SerialName("string_not_equal") + @Description("The effect will be applied for every attribute of the specified type with a non-matching string value") + data class StringNotEqual( + val type: GraphAgentServerAttributeType, + val string: String, + val effect: GraphAgentServerScorerEffect.Flat + ) : GraphAgentServerCustomScorer() + + fun getScore(server: GraphAgentServer): Double = + when (this) { + is IsTrue -> { + server.attributes.filter { + it.type == type && it is GraphAgentServerAttribute.Boolean && it.value + }.sumOf { effect.apply() } + } + is IsFalse -> { + server.attributes.filter { + it.type == type && it is GraphAgentServerAttribute.Boolean && !it.value + }.sumOf { effect.apply() } + } + is IsNotPresent -> { + val condition = server.attributes.firstOrNull { + it.type == type + } == null + + if (condition) effect.apply() else 0.0 + } + is IsPresent -> { + server.attributes.filter { + it.type == type && it is GraphAgentServerAttribute.Boolean && !it.value + }.sumOf { + when (it) { + is GraphAgentServerAttribute.Number -> effect.apply(it.value) + else -> effect.apply() + } + } + } + is StringEqual -> { + server.attributes.filter { + it.type == type && it is GraphAgentServerAttribute.String && it.value == string + }.sumOf { effect.apply() } + } + is StringNotEqual -> { + server.attributes.filter { + it.type == type && it is GraphAgentServerAttribute.String && it.value != string + }.sumOf { effect.apply() } + } + } +} + +@Serializable +@JsonClassDiscriminator("type") +sealed class GraphAgentServerScoring() { + abstract fun getScore(server: GraphAgentServer): Double + + @Serializable + @SerialName("custom") + @Description("Custom server scoring. Weights can be added on a flat or multiplier basis per attribute") + data class Custom( + val scorers: List + ) : GraphAgentServerScoring() { + override fun getScore(server: GraphAgentServer): Double { + return scorers.sumOf { it.getScore(server) } + } + } + + @Serializable + @SerialName("default") + @Description("Default server scoring. No weights assigned to any server attribute") + class Default : GraphAgentServerScoring() { + override fun getScore(server: GraphAgentServer): Double { + return 1.0 + } + } + + // todo: better defaults/presets +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentServerSource.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentServerSource.kt new file mode 100644 index 0000000000000000000000000000000000000000..1f5244af2f70c0aab604803adbf65bb8e7680e6e --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/graph/GraphAgentServerSource.kt @@ -0,0 +1,21 @@ +package org.coralprotocol.coralserver.agent.graph + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +sealed class GraphAgentServerSource { + @Serializable + @SerialName("servers") + data class Servers( + val servers: List + ) : GraphAgentServerSource() + + // TODO: implement this properly! + // an indexer will be a server that will provide another list of servers to query. We will allow people to host + // their own indexers and we will also provide an indexer connected to our agent marketplace. + @Serializable + data class Indexer( + val indexer: String + ) : GraphAgentServerSource() +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentExport.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentExport.kt new file mode 100644 index 0000000000000000000000000000000000000000..fcb76a421f53bb88350fbee6d208526b299d65e5 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentExport.kt @@ -0,0 +1,36 @@ +package org.coralprotocol.coralserver.agent.registry + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.coralprotocol.coralserver.agent.runtime.RuntimeId + +@Serializable +data class AgentExportPricing( + @SerialName("min_price") + val minPrice: Double, + + @SerialName("max_price") + val maxPrice: Double, +) + +@Serializable +data class AgentExport( + val agent: RegistryAgent, + val runtimes: Map, + val quantity: UInt +) + +@Serializable +data class PublicAgentExport( + val agent: PublicRegistryAgent, + val runtimes: Map, + val quantity: UInt +) + +fun AgentExport.toPublic(id: String): PublicAgentExport { + return PublicAgentExport( + agent = agent.toPublic(id), + runtimes = runtimes, + quantity = quantity + ) +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentOption.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentOption.kt new file mode 100644 index 0000000000000000000000000000000000000000..82d3e0365e79bbbcb65490559fa8aabe2aaaa051 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentOption.kt @@ -0,0 +1,60 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.agent.registry + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator + +@Serializable +enum class AgentOptionType { + @SerialName("string") + STRING, + + @SerialName("secret") + SECRET, + + @SerialName("number") + NUMBER, +} + +@Serializable +@JsonClassDiscriminator("type") +sealed class AgentOption { + abstract val description: kotlin.String? + abstract val required: Boolean + + @Serializable + @SerialName("string") + data class String( + override val description: kotlin.String? = null, + val default: kotlin.String? = null + ) : AgentOption() { + override val required: Boolean = default == null + } + + @Serializable + @SerialName("number") + data class Number( + override val description: kotlin.String? = null, + val default: Double? = null, + ) : AgentOption() { + override val required: Boolean = default == null + } + + @Serializable + @SerialName("secret") + data class Secret( + override val description: kotlin.String? = null, + ) : AgentOption() { + override val required: Boolean = true + } +} + +fun AgentOption.defaultAsValue(): AgentOptionValue? = + when (this) { + is AgentOption.String -> this.default?.let { AgentOptionValue.String(it) } + is AgentOption.Number -> this.default?.let { AgentOptionValue.Number(it) } + is AgentOption.Secret -> null + } diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentOptionValue.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentOptionValue.kt new file mode 100644 index 0000000000000000000000000000000000000000..3143ff5d7e13168fa7490829cafd62d865af2672 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentOptionValue.kt @@ -0,0 +1,27 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.agent.registry + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator + +@Serializable +@JsonClassDiscriminator("type") +sealed class AgentOptionValue { + + @Serializable + @SerialName("string") + data class String(val value: kotlin.String) : AgentOptionValue() + + + @Serializable + @SerialName("number") + data class Number(val value: Double) : AgentOptionValue() +} + +fun AgentOptionValue.toStringValue(): String = when (this) { + is AgentOptionValue.String -> value + is AgentOptionValue.Number -> value.toString() +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentRegistry.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentRegistry.kt new file mode 100644 index 0000000000000000000000000000000000000000..e6e1b2c7aa8facdbbd35e3815c1d9bb506792bcf --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/AgentRegistry.kt @@ -0,0 +1,9 @@ +package org.coralprotocol.coralserver.agent.registry + +import kotlinx.serialization.Serializable + +@Serializable +data class AgentRegistry( + val importedAgents: Map, + val exportedAgents: Map +) \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/RegistryAgent.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/RegistryAgent.kt new file mode 100644 index 0000000000000000000000000000000000000000..e7175d9e0c0817954abc5fa67558f222ed38d294 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/RegistryAgent.kt @@ -0,0 +1,36 @@ +package org.coralprotocol.coralserver.agent.registry + +import UnresolvedAgentOption +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.coralprotocol.coralserver.agent.runtime.AgentRuntimes + +@Serializable +data class UnresolvedRegistryAgent( + val runtimes: AgentRuntimes, + val options: Map +) { + fun resolve(): RegistryAgent = RegistryAgent( + runtimes = runtimes, + options = options.mapValues { (_, option) -> + option.resolve() + } + ) +} + +@Serializable +data class RegistryAgent( + val runtimes: AgentRuntimes, + val options: Map +) + +@Serializable +data class PublicRegistryAgent( + val id: String, + val options: Map +) + +fun RegistryAgent.toPublic(id: String): PublicRegistryAgent = PublicRegistryAgent( + id = id, + options = options +) \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/RegistryException.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/RegistryException.kt new file mode 100644 index 0000000000000000000000000000000000000000..a481485387e2d3f1dc95b72665b56774809979ff --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/RegistryException.kt @@ -0,0 +1,3 @@ +package org.coralprotocol.coralserver.agent.registry + +data class RegistryException(override val message: String?) : Exception(message) \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedAgentExport.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedAgentExport.kt new file mode 100644 index 0000000000000000000000000000000000000000..5438c9e022df00120e438e1fd94b94e6eb4e5902 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedAgentExport.kt @@ -0,0 +1,49 @@ +package org.coralprotocol.coralserver.agent.registry + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.util.toUpperCasePreservingASCIIRules +import kotlinx.serialization.Serializable +import org.coralprotocol.coralserver.agent.runtime.RuntimeId + +val logger = KotlinLogging.logger {} + +@Serializable +data class UnresolvedAgentExport( + val quantity: UInt, + + // resolves to List, maybe make this type a List when ktoml supports lists of enums: + // https://github.com/orchestr7/ktoml/issues/340 + val runtimes: Map + + // todo: pricing here +) { + fun resolve(name: String, agent: RegistryAgent): AgentExport { + if (quantity == 0u) { + throw RegistryException("Cannot export 0 \"$name\" agents") + } + + // Runtimes must be either Executable, Docker or Phala and the runtime must be defined on the imported agent + // todo: disallow exporting of Executable runtimes? + val validRuntimes = runtimes.mapNotNull { (runtimeName, pricing) -> + try { + val runtimeId = RuntimeId.valueOf(runtimeName.toUpperCasePreservingASCIIRules()) + if (agent.runtimes.getById(runtimeId) == null) { + logger.warn { "Runtime \"$runtimeName\" is not defined for agent \"$name\"" } + null + } else { + runtimeId to pricing + } + } + catch (_: IllegalArgumentException) { + logger.warn { "Invalid runtime \"$runtimeName\" for agent \"$name\"" } + null + } + }.toMap() + + if (validRuntimes.isEmpty()) { + throw RegistryException("Cannot export agent \"$name\" with no runtimes") + } + + return AgentExport(agent, validRuntimes, quantity) + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedAgentOption.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedAgentOption.kt new file mode 100644 index 0000000000000000000000000000000000000000..d13fa732f3227f386ca50ec2bcbba83aa69aae98 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedAgentOption.kt @@ -0,0 +1,69 @@ +import kotlinx.serialization.KSerializer +import kotlinx.serialization.Serializable +import kotlinx.serialization.descriptors.buildClassSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder +import org.coralprotocol.coralserver.agent.registry.AgentOption +import org.coralprotocol.coralserver.agent.registry.AgentOptionType + +@Serializable +data class UnresolvedAgentOption( + val type: AgentOptionType, + val description: String? = null, + val default: AgentOptionDefault? = null +) { + fun resolve(): AgentOption = + when (type) { + AgentOptionType.STRING -> AgentOption.String(description, + when (default) { + is AgentOptionDefault.Number -> throw IllegalArgumentException("Cannot use number as default for string option") + is AgentOptionDefault.String -> default.value + null -> null + } + ) + AgentOptionType.NUMBER -> AgentOption.Number(description, + when (default) { + is AgentOptionDefault.Number -> default.value + is AgentOptionDefault.String -> throw IllegalArgumentException("Cannot use string as default for number option") + null -> null + }) + AgentOptionType.SECRET -> AgentOption.Secret(description) + } +} + +@Serializable(with = AgentOptionDefaultSerializer::class) +sealed class AgentOptionDefault() { + @Serializable + data class String(val value: kotlin.String) : AgentOptionDefault() + + @Serializable + data class Number(val value: Double) : AgentOptionDefault() +} + +object AgentOptionDefaultSerializer : KSerializer { + override val descriptor = buildClassSerialDescriptor("AgentOptionDefault") + override fun deserialize(decoder: Decoder): AgentOptionDefault { + try { + return AgentOptionDefault.String(decoder.decodeString()) + } + catch (_: Exception) { + + } + + try { + return AgentOptionDefault.Number(decoder.decodeDouble()) + } + catch (_: Exception) { + + } + + throw IllegalArgumentException("Unsupported option format") + } + + override fun serialize(encoder: Encoder, value: AgentOptionDefault) { + when (value) { + is AgentOptionDefault.String -> encoder.encodeString(value.value) + is AgentOptionDefault.Number -> encoder.encodeDouble(value.value) + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedAgentRegistry.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedAgentRegistry.kt new file mode 100644 index 0000000000000000000000000000000000000000..6d82bb54126fa4b4f6d25fdf26695eac428adb5c --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedAgentRegistry.kt @@ -0,0 +1,28 @@ +package org.coralprotocol.coralserver.agent.registry + +import com.akuleshov7.ktoml.Toml +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +class UnresolvedAgentRegistry() { + @SerialName("agent-import") + val importedAgents: Map = HashMap() + + @SerialName("agent-export") + val exportedAgents: Map = HashMap() + + fun resolve(toml: Toml): AgentRegistry { + val importedAgents = importedAgents.mapValues { (_, agent) -> + agent.resolve(toml) + } + + return AgentRegistry( + importedAgents, + exportedAgents.mapValues { (name, unresolvedExport) -> + val agent = importedAgents[name] ?: throw RegistryException("Cannot export unknown agent: $name") + unresolvedExport.resolve(name, agent) + }, + ) + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedRegistryAgentReference.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedRegistryAgentReference.kt new file mode 100644 index 0000000000000000000000000000000000000000..65ef212193fedd061494b7f958d8f18706086eeb --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/UnresolvedRegistryAgentReference.kt @@ -0,0 +1,99 @@ +package org.coralprotocol.coralserver.agent.registry + +import com.akuleshov7.ktoml.Toml +import com.akuleshov7.ktoml.source.decodeFromStream +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.KSerializer +import kotlinx.serialization.Serializable +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.serialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder +import java.nio.file.Path +import kotlin.io.path.inputStream + + +@Serializable(with = UnresolvedRegistryAgentSerializer::class) +sealed class UnresolvedRegistryAgentReference { + /** + * Marketplace agent + */ + @Serializable + data class Marketplace( + val version: String + ) : UnresolvedRegistryAgentReference() + + /** + * Local (on the disk) registry agent + */ + @Serializable + data class Local( + val path: String, + ) : UnresolvedRegistryAgentReference() + + /** + * Agent on a remote git repository + */ + @Serializable + data class Git( + val git: String, + val branch: String? = null, + val tag: String? = null, + val rev: String? = null, + ) : UnresolvedRegistryAgentReference() + + fun resolve(toml: Toml): RegistryAgent { + when (this) { + is Marketplace -> TODO("marketplace agents not supported yet") + is Git -> TODO("git agents not supported yet") + is Local -> { + val agentTomlFile = Path.of(path, "coral-agent.toml") + val agent = toml.decodeFromStream(agentTomlFile.inputStream()) + return agent.resolve() + } + } + } +} + +@OptIn(ExperimentalSerializationApi::class) +object UnresolvedRegistryAgentSerializer : KSerializer { + override val descriptor: SerialDescriptor + get() = SerialDescriptor("UnresolvedRegistryAgent", serialDescriptor()) + + override fun serialize( + encoder: Encoder, + value: UnresolvedRegistryAgentReference + ) { + throw UnsupportedOperationException("Serialization is not supported") + } + + override fun deserialize(decoder: Decoder): UnresolvedRegistryAgentReference { + try { + return decoder.decodeSerializableValue(UnresolvedRegistryAgentReference.Local.serializer()) + } catch (_: Exception) { + + } + + try { + return decoder.decodeSerializableValue(UnresolvedRegistryAgentReference.Git.serializer()) + } catch (_: Exception) { + + } + + try { + /* + * This works for : + * 1) marketplace = "0.1.0" + * and: + * 2) marketplace = { version = "0.1.1" } + * + * ... but I'm not sure if it is designed intentionally to work this way + */ + return UnresolvedRegistryAgentReference.Marketplace(decoder.decodeString()) + } catch (_: Exception) { + + } + + throw IllegalArgumentException("Unsupported agent format") + } +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/AgentRuntimes.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/AgentRuntimes.kt new file mode 100644 index 0000000000000000000000000000000000000000..0d05c933fb66f62ea955a68d4fa8433049e9a51a --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/AgentRuntimes.kt @@ -0,0 +1,53 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.agent.runtime + +import com.chrynan.uri.core.Uri +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.coralprotocol.coralserver.EventBus +import org.coralprotocol.coralserver.agent.registry.AgentOptionValue +import org.coralprotocol.coralserver.session.SessionManager + +@Serializable +enum class RuntimeId { + @SerialName("executable") + EXECUTABLE, + + @SerialName("docker") + DOCKER +} + +data class RuntimeParams( + val sessionId: String, + val agentName: String, + val mcpServerPort: UShort, + val mcpServerRelativeUri: Uri, + + val systemPrompt: String?, + val options: Map, +) + +@Serializable +@SerialName("runtime") +class AgentRuntimes( + @SerialName("executable") + private val executableRuntime: ExecutableRuntime? = null, + + @SerialName("docker") + private val dockerRuntime: DockerRuntime? = null, +) : Orchestrate { + override fun spawn( + params: RuntimeParams, + eventBus: EventBus, + sessionManager: SessionManager? + ): OrchestratorHandle { + TODO("runtime must be selected") + } + + fun getById(runtimeId: RuntimeId): Orchestrate? = when (runtimeId) { + RuntimeId.EXECUTABLE -> executableRuntime + RuntimeId.DOCKER -> dockerRuntime + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/CoralOrchestratedEnvs.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/CoralOrchestratedEnvs.kt new file mode 100644 index 0000000000000000000000000000000000000000..a2966c3a5c71f30fccae912811c1e5c037c1fff3 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/CoralOrchestratedEnvs.kt @@ -0,0 +1,27 @@ +package org.coralprotocol.coralserver.agent.runtime + +import com.chrynan.uri.core.Uri +import com.chrynan.uri.core.pathSegments + +fun getCoralSystemEnvs( + params: RuntimeParams, + coralConnectionUrl: Uri, + orchestrationRuntime: String +): Map { + // Confirm last segment is "sse" to ensure it's a valid Coral connection URL + if (coralConnectionUrl.pathSegments.isEmpty() || coralConnectionUrl.pathSegments.last() != "sse") { + throw IllegalArgumentException("Coral connection URL must end with '/sse'") + } + val sessionId = coralConnectionUrl.pathSegments.dropLast(1).lastOrNull() + ?: throw IllegalArgumentException("Coral connection URL must contain a session ID in the path") + return listOfNotNull( + "CORAL_CONNECTION_URL" to coralConnectionUrl.toUriString().value, + "CORAL_AGENT_ID" to params.agentName, + "CORAL_ORCHESTRATION_RUNTIME" to orchestrationRuntime, + "CORAL_SESSION_ID" to sessionId, + "CORAL_SSE_URL" to with(coralConnectionUrl) { + "${scheme}://$host:$port$path" + }, + params.systemPrompt?.let { "CORAL_PROMPT_SYSTEM" to it } + ).toMap() +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/DockerRuntime.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/DockerRuntime.kt new file mode 100644 index 0000000000000000000000000000000000000000..adb0576f4de3deda25c10d0e99cdc929ea8f56da --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/DockerRuntime.kt @@ -0,0 +1,158 @@ +package org.coralprotocol.coralserver.agent.runtime + +import com.chrynan.uri.core.Uri +import com.chrynan.uri.core.parse +import com.github.dockerjava.api.async.ResultCallback +import com.github.dockerjava.api.exception.NotModifiedException +import com.github.dockerjava.api.model.Frame +import com.github.dockerjava.api.model.StreamType +import com.github.dockerjava.core.DefaultDockerClientConfig +import com.github.dockerjava.core.DockerClientBuilder +import com.github.dockerjava.core.DockerClientConfig +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.coralprotocol.coralserver.EventBus +import org.coralprotocol.coralserver.agent.runtime.executable.EnvVar +import org.coralprotocol.coralserver.session.SessionManager +import java.io.File +import kotlin.time.Duration.Companion.seconds + +private val logger = KotlinLogging.logger {} + +@Serializable +@SerialName("docker") +data class DockerRuntime( + val image: String, + val environment: List = listOf() +) : Orchestrate { + private val dockerClientConfig: DockerClientConfig = DefaultDockerClientConfig.createDefaultConfigBuilder() + .withDockerHost(getDockerSocket()) + .build() + private val dockerClient = DockerClientBuilder.getInstance(dockerClientConfig).build() + + override fun spawn( + params: RuntimeParams, + bus: EventBus, + sessionManager: SessionManager?, + ): OrchestratorHandle { + logger.info { "Spawning Docker container with image: $image" } + val fullConnectionUrl = + "http://host.docker.internal:${params.mcpServerPort}/${params.mcpServerRelativeUri.path}${params.mcpServerRelativeUri.query?.let { "?$it" } ?: ""}" + + val resolvedEnvs = this.environment.map { + val (key, value) = it.resolve(params.options) + "$key=$value" + } + val allEnvs = resolvedEnvs + getCoralSystemEnvs( + params, + Uri.parse(fullConnectionUrl), + "docker" + ).map { (key, value) -> "$key=$value" } + + val containerCreation = dockerClient.createContainerCmd(image) + .withName(getDockerContainerName(params.mcpServerRelativeUri, params.agentName)) + .withEnv(allEnvs) + .withAttachStdout(true) + .withAttachStderr(true) + .withAttachStdin(false) // Stdin makes no sense with orchestration + .exec() + + dockerClient.startContainerCmd(containerCreation.id).exec() + + // Attach to container streams for output redirection + val attachCmd = dockerClient.attachContainerCmd(containerCreation.id) + .withStdOut(true) + .withStdErr(true) + .withFollowStream(true) + .withLogs(true) + + val streamCallback = attachCmd.exec(object : ResultCallback.Adapter() { + override fun onNext(frame: Frame) { + val message = String(frame.payload).trimEnd('\n') + when (frame.streamType) { + StreamType.STDOUT -> { + logger.info { "[STDOUT] ${params.agentName}: $message" } + } + + StreamType.STDERR -> { + logger.info { "[STDERR] ${params.agentName}: $message" } + } + + else -> { + logger.warn { "[UNKNOWN] ${params.agentName}: $message" } + } + } + } + }) + + return object : OrchestratorHandle { + override suspend fun destroy() { + withContext(processContext) { + try { + streamCallback.close() + } catch (e: Exception) { + logger.warn { "Failed to close stream callback: ${e.message}" } + } + + warnOnNotModifiedExceptions { dockerClient.stopContainerCmd(containerCreation.id).exec() } + warnOnNotModifiedExceptions { + withTimeoutOrNull(30.seconds) { + dockerClient.removeContainerCmd(containerCreation.id) + .withRemoveVolumes(true) + .exec() + return@withTimeoutOrNull true + } ?: let { + logger.warn { "Docker container ${params.agentName} did not stop in time, force removing it" } + dockerClient.removeContainerCmd(containerCreation.id) + .withRemoveVolumes(true) + .withForce(true) + .exec() + } + logger.info { "Docker container ${params.agentName} stopped and removed" } + } + } + } + } + } +} + +private suspend fun warnOnNotModifiedExceptions(block: suspend () -> Unit): Unit { + try { + block() + } catch (e: NotModifiedException) { + logger.warn { "Docker operation was not modified: ${e.message}" } + } catch (e: Exception) { + throw e + } +} + +private fun String.asDockerContainerName(): String { + return this.replace(Regex("[^a-zA-Z0-9_]"), "_") + .take(63) // Network-resolvable name limit + .trim('_') +} + +private fun getDockerContainerName(relativeMcpServerUri: Uri, agentName: String): String { + // SessionID is too long for Docker container names, so we use a hash of the URI for deduplication. + val randomSuffix = relativeMcpServerUri.toUriString().hashCode().toString(16).take(11) + return "${agentName.take(52)}_$randomSuffix".asDockerContainerName() +} + +private fun getDockerSocket(): String { + val specifiedSocket = System.getenv("CORAL_DOCKER_SOCKET") + if (specifiedSocket != null) { + return specifiedSocket + } + + // Check whether colima is installed and use its socket if available + val homeDir = System.getProperty("user.home") + val colimaSocket = "$homeDir/.colima/default/docker.sock" + return if (File(colimaSocket).exists()) { + "unix://$colimaSocket" + } else { + "unix:///var/run/docker.sock" // Default Docker socket + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/ExecutableRuntime.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/ExecutableRuntime.kt new file mode 100644 index 0000000000000000000000000000000000000000..2f21fc8620d40214a6c93c2d39dee23c5df74a41 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/ExecutableRuntime.kt @@ -0,0 +1,108 @@ +package org.coralprotocol.coralserver.agent.runtime + +import com.chrynan.uri.core.Uri +import com.chrynan.uri.core.fromParts +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.DelicateCoroutinesApi +import kotlinx.coroutines.newFixedThreadPoolContext +import kotlinx.coroutines.withContext +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.coralprotocol.coralserver.EventBus +import org.coralprotocol.coralserver.models.AgentState +import org.coralprotocol.coralserver.agent.runtime.executable.EnvVar +import org.coralprotocol.coralserver.session.SessionManager +import java.util.concurrent.TimeUnit +import kotlin.collections.iterator +import kotlin.concurrent.thread + +private val logger = KotlinLogging.logger {} + +@Serializable +@SerialName("executable") +data class ExecutableRuntime( + val command: List, + val environment: List = listOf() +) : Orchestrate { + override fun spawn( + params: RuntimeParams, + bus: EventBus, + sessionManager: SessionManager?, + ): OrchestratorHandle { + val processBuilder = ProcessBuilder() + val processEnvironment = processBuilder.environment() + val path = processEnvironment["PATH"] ?: "" + processEnvironment.clear() + processEnvironment["PATH"] = path + + // TODO: error if someone tries passing coral system envs themselves + val coralConnectionUrl = Uri.fromParts( + scheme = "http", + host = "localhost", // Executables run on the same host as the Coral server + port = params.mcpServerPort.toInt(), + path = params.mcpServerRelativeUri.path, + query = params.mcpServerRelativeUri.query + ) + + val resolvedOptions = this.environment.associate { + val (key, value) = it.resolve(params.options); + key to (value ?: "") + } + val envsToSet = resolvedOptions + getCoralSystemEnvs(params, coralConnectionUrl, "executable") + for (env in envsToSet) { + processEnvironment[env.key] = env.value + } + + processBuilder.command(command) + + logger.info { "spawning process..." } + val process = processBuilder.start() + + // TODO (alan): re-evaluate this when it becomes a bottleneck + + thread(isDaemon = true) { + process.waitFor() + bus.emit(RuntimeEvent.Stopped()) + logger.warn {"Process exited for Agent {params.agentName}"}; + sessionManager?.getSession(params.sessionId)?.setAgentState(params.agentName, AgentState.Dead); + } + + thread(isDaemon = true) { + val reader = process.inputStream.bufferedReader() + reader.forEachLine { line -> + run { + bus.emit(RuntimeEvent.Log(kind = LogKind.STDOUT, message = line)) + logger.info { + "[STDOUT] ${params.agentName}: $line" + } + } + } + } + thread(isDaemon = true) { + val reader = process.errorStream.bufferedReader() + reader.forEachLine { line -> + run { + bus.emit(RuntimeEvent.Log(kind = LogKind.STDERR, message = line)) + logger.error { + "[STDERR] ${params.agentName}: $line" + } + } + } + } + + return object : OrchestratorHandle { + override suspend fun destroy() { + withContext(processContext) { + process.destroy() + process.waitFor(30, TimeUnit.SECONDS) + process.destroyForcibly() + logger.info { "Process exited" } + } + } + } + + } +} + +@OptIn(DelicateCoroutinesApi::class) +val processContext = newFixedThreadPoolContext(10, "processContext") \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/Orchestrator.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/Orchestrator.kt new file mode 100644 index 0000000000000000000000000000000000000000..c622c94bd8f9f3cd9f36017f340009e9b15c92b2 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/Orchestrator.kt @@ -0,0 +1,123 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.agent.runtime + +import com.chrynan.uri.core.Uri +import kotlinx.coroutines.* +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator +import org.coralprotocol.coralserver.EventBus +import org.coralprotocol.coralserver.config.ConfigCollection +import org.coralprotocol.coralserver.agent.graph.GraphAgent +import org.coralprotocol.coralserver.agent.graph.GraphAgentProvider +import org.coralprotocol.coralserver.agent.graph.GraphAgentServerSource +import org.coralprotocol.coralserver.session.SessionManager + +enum class LogKind { + STDOUT, + STDERR, +} + +@Serializable +@JsonClassDiscriminator("type") +sealed interface RuntimeEvent { + @Serializable + @SerialName("log") + data class Log( + val timestamp: Long = System.currentTimeMillis(), + val kind: LogKind, + val message: String + ) : RuntimeEvent + + @Serializable + @SerialName("stopped") + data class Stopped(val timestamp: Long = System.currentTimeMillis()): RuntimeEvent +} + +interface Orchestrate { + fun spawn( + params: RuntimeParams, + eventBus: EventBus, + sessionManager: SessionManager?, + ): OrchestratorHandle +} + +interface OrchestratorHandle { + suspend fun destroy() +} + +class Orchestrator( + val app: ConfigCollection = ConfigCollection(null), +) { + private val eventBusses: MutableMap>> = mutableMapOf() + private val handles: MutableList = mutableListOf() + + + fun getBus(sessionId: String, agentId: String): EventBus? = eventBusses[sessionId]?.get(agentId) + + private fun getBusOrCreate(sessionId: String, agentId: String) = eventBusses.getOrPut(sessionId) { + mutableMapOf() + }.getOrPut(agentId) { + EventBus(replay = 512) + } + + fun spawn(sessionId: String, graphAgent: GraphAgent, agentName: String, port: UShort, relativeMcpServerUri: Uri, sessionManager: SessionManager?) { + val params = RuntimeParams( + sessionId = sessionId, + agentName = agentName, + mcpServerPort = port, + mcpServerRelativeUri = relativeMcpServerUri, + systemPrompt = graphAgent.systemPrompt, + options = graphAgent.options + ) + + val agent = app.registry.importedAgents[graphAgent.name] + ?: throw IllegalArgumentException("Cannot spawn unknown agent: ${graphAgent.name}") + + when (val provider = graphAgent.provider) { + is GraphAgentProvider.Local -> { + val runtime = agent.runtimes.getById(provider.runtime) ?: + throw IllegalArgumentException("The requested runtime: ${provider.runtime} is not supported on agent ${graphAgent.name}") + + handles.add(runtime.spawn( + params, + getBusOrCreate(params.sessionId, params.agentName), + sessionManager) + ) + } + is GraphAgentProvider.Remote -> { + val rankedServers = when (provider.serverSource) { + is GraphAgentServerSource.Servers -> { + provider.serverSource.servers.sortedBy { + provider.serverScoring?.getScore(it) ?: 1.0 + } + } + is GraphAgentServerSource.Indexer -> TODO("indexer server source not yet supported") + } + + /* + Workflow: + 1. Iterate over ranked servers (maintaining order!), finding the first that responds to "pings" + 2. Request the agent from the server. If they decline, move to the next server + 3. If they accept: + a. Do payment stuff, if this fails, move to the next server + b. Open WebSocket connection with the server, this WebSocket connection can be treated like a + bus for a process or docker container + c. Tie the life-cycle of an agent to the WebSocket connection and vice versa, again similarly to a + process or container + d. More payment stuff? + 4. If the list of servers is exhausted without having found a suitable server to provide the agent, + an exception should be thrown + */ + + TODO("support remote runtime agents") + } + } + } + + suspend fun destroy(): Unit = coroutineScope { + handles.map { async { it.destroy() } }.awaitAll() + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/executable/EnvVar.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/executable/EnvVar.kt new file mode 100644 index 0000000000000000000000000000000000000000..8ca17ac74e9cf52469ccc7938ce38bb899a148b1 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/executable/EnvVar.kt @@ -0,0 +1,47 @@ +package org.coralprotocol.coralserver.agent.runtime.executable + +import kotlinx.serialization.Serializable +import org.coralprotocol.coralserver.agent.registry.AgentOptionValue +import org.coralprotocol.coralserver.agent.registry.toStringValue + +@Serializable +data class EnvVar( + val name: String? = null, + val value: String? = null, + val from: String? = null, + + val option: String? = null, +) { + // TODO (alan): bake this validation into the type system + // EnvVar should be a sum type of 'name/from', 'option' & 'name/value' + fun validate() { + if (option != null && (from != null || value != null || name != null)) { + throw IllegalArgumentException("'option' key is shorthand for 'name' & 'from', it must be used on its own") + } + if (name != null && (value == null && from == null)) { + throw IllegalArgumentException("'value' or 'from' must be provided") + } + if (from != null && value != null) { + throw IllegalArgumentException("'from' and 'value' are mutually exclusive") + } + if (name == null && value == null && from == null && option == null) { + throw IllegalArgumentException("Invalid environment variable definition") + } + } + + fun resolve(options: Map): Pair { + if (option != null) { + val opt = options[option] ?: throw IllegalArgumentException("Undefined option '$option'") + return Pair(option, opt.toStringValue()) + } + val name = name ?: throw IllegalArgumentException("name not provided") + if(from != null) { + val opt = options[from] ?: throw IllegalArgumentException("Undefined option '$from'") + return Pair(from, opt.toStringValue()) + } + if(value != null) { + return Pair(name, value) + } + throw IllegalArgumentException("Invalid environment variable definition") + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/config/AppConfig.kt b/src/main/kotlin/org/coralprotocol/coralserver/config/AppConfig.kt new file mode 100644 index 0000000000000000000000000000000000000000..ee8d9182e0c8723a28d3febc04df5ec5e5b03011 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/config/AppConfig.kt @@ -0,0 +1,36 @@ +package org.coralprotocol.coralserver.config + +import kotlinx.serialization.Serializable + + +// TODO: Applications are a work in progress. This is safe to ignore for now. + +/** + * Main application configuration. + */ +@Serializable +data class AppConfig( + val applications: List = emptyList(), + val applicationSource: ApplicationSourceConfig? = null +) + +/** + * Configuration for an application. + */ +@Serializable +data class ApplicationConfig( + val id: String, + val name: String, + val description: String = "", + val privacyKeys: List = emptyList() +) + +/** + * Configuration for application source (for future use). + */ +@Serializable +data class ApplicationSourceConfig( + val type: String, + val url: String? = null, + val refreshIntervalSeconds: Int = 3600 +) \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/config/ConfigCollection.kt b/src/main/kotlin/org/coralprotocol/coralserver/config/ConfigCollection.kt new file mode 100644 index 0000000000000000000000000000000000000000..9f5d29700af2a961b50c399e5ddd418563d048b9 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/config/ConfigCollection.kt @@ -0,0 +1,217 @@ +package org.coralprotocol.coralserver.config + +import com.akuleshov7.ktoml.Toml +import com.akuleshov7.ktoml.TomlInputConfig +import com.akuleshov7.ktoml.source.decodeFromStream +import com.charleskorn.kaml.PolymorphismStyle +import com.charleskorn.kaml.Yaml +import com.charleskorn.kaml.YamlConfiguration +import com.charleskorn.kaml.decodeFromStream +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.* +import kotlinx.io.files.FileNotFoundException +import org.coralprotocol.coralserver.agent.registry.AgentRegistry +import org.coralprotocol.coralserver.agent.registry.RegistryException +import org.coralprotocol.coralserver.agent.registry.UnresolvedAgentRegistry +import java.nio.file.* +import kotlin.io.path.listDirectoryEntries + +private val logger = KotlinLogging.logger {} + +/** + * Creates a flow WatchEvent from a watchService + */ +fun WatchService.eventFlow(): Flow>> = flow { + while (currentCoroutineContext().isActive) { + coroutineScope { + var key: WatchKey? = null + val job = launch { + runInterruptible(Dispatchers.IO) { + key = take() + } + } + job.join() + val currentKey = key + if (currentKey != null) { + emit(currentKey.pollEvents()) + currentKey.reset() + } + } + } +} + +/** + * Returns a flow with the files inside a folder (with a given glob) + */ +fun Path.listDirectoryEntriesFlow(glob: String): Flow> { + val watchService = watch() + return watchService.eventFlow() + .map { listDirectoryEntries(glob) } + .onStart { emit(listDirectoryEntries(glob)) } + .onCompletion { watchService.close() } + .flowOn(Dispatchers.IO) +} + +/** + * Creates a new WatchService for any Event + */ +fun Path.watch(): WatchService { + return watch( + StandardWatchEventKinds.ENTRY_CREATE, StandardWatchEventKinds.ENTRY_MODIFY, + StandardWatchEventKinds.OVERFLOW, StandardWatchEventKinds.ENTRY_DELETE + ) +} + +/** + * Creates a new watch service + */ +fun Path.watch(vararg events: WatchEvent.Kind) = + fileSystem.newWatchService()!!.apply { register(this, events) } + + +/** + * Loads application configuration from resources. + */ +class ConfigCollection( + val appConfigPath: Path? = getConfigPath("application.yaml"), + val registryPath: Path? = getConfigPath("registry.toml"), + val defaultConfig: AppConfig = AppConfig( + applications = listOf( + ApplicationConfig( + id = "default-app", + name = "Default Application", + description = "Default application (fallback)", + privacyKeys = listOf("default-key", "public") + ) + ) + ), + val defaultRegistry: AgentRegistry = AgentRegistry( + mapOf(), + mapOf() + ) +) { + var appConfig: AppConfig = loadAppConfig(appConfigPath) + private set + + var registry: AgentRegistry = loadRegistry(registryPath) + private set + + private val watchJob: Job? = appConfigPath?.let { + CoroutineScope(Dispatchers.Default).launch { + logger.info{ "Watching for config changes in '${it.parent}'..." } + it.parent.listDirectoryEntriesFlow("application.yaml*").distinctUntilChanged().collect { + logger.info { "application.yaml changed. Reloading..." } + appConfig = loadAppConfig(appConfigPath) + } + } + } + + fun stopWatch() { + watchJob?.cancel() + } + + companion object { + private fun getConfigPath(file: String): Path? { + // Try to load from resources if no config path set + return when (val configPath = System.getenv("CONFIG_PATH")) { + null -> if(Path.of("./$file").toFile().exists()) { + Path.of("./$file") // Check local application.yaml + } else Path.of("./src/main/resources/$file") // Assume running from source when config path not specified + + else -> (Path.of(configPath, file)) + } + } + } + + /** + * Loads the application configuration from the resources. + * If the configuration is already loaded, returns the cached instance. + */ + private fun loadAppConfig(path: Path?): AppConfig = try { + val file = path?.toFile() + if (file != null) { + if (!file.exists()) { + throw FileNotFoundException(file.absolutePath) + } + + val c = + Yaml(configuration = YamlConfiguration(polymorphismStyle = PolymorphismStyle.Property)).decodeFromStream( + file.inputStream() + ) + appConfig = c + + logger.info { "Loaded configuration with ${c.applications.size} applications" } + c + } else { + throw Exception("Failed to lookup resource.") + } + } catch (e: Exception) { + logger.error(e) { "Failed to load configuration, using default" } + defaultConfig + } + + /** + * Loads the agent registry from the specified path. + */ + private fun loadRegistry(path: Path?): AgentRegistry = try { + val file = path?.toFile() + if (file != null) { + if (!file.exists()) { + throw FileNotFoundException(file.absolutePath) + } + + val toml = Toml( + inputConfig = TomlInputConfig( + // allow/prohibit unknown names during the deserialization, default false + ignoreUnknownNames = true, + // allow/prohibit empty values like "a = # comment", default true + allowEmptyValues = true, + // allow/prohibit null values like "a = null", default true + allowNullValues = true, + // allow/prohibit escaping of single quotes in literal strings, default true + allowEscapedQuotesInLiteralStrings = true, + // allow/prohibit processing of empty toml, if false - throws an InternalDecodingException exception, default is true + allowEmptyToml = true, + // allow/prohibit default values during the deserialization, default is false + ignoreDefaultValues = false, + ) + ) + + val reg = toml.decodeFromStream(file.inputStream()) + .resolve(toml) + + logger.info { "Loaded registry with ${reg.importedAgents.size} imported agents and ${reg.exportedAgents.size} exported agents" } + reg + } else { + throw Exception("Failed to load registry file") + } + } + catch (e: RegistryException) { + logger.error{ "Error with registry file: ${e.message}" } + logger.warn{ "Using default registry" } + defaultRegistry + } + catch (e: Exception) { + logger.error(e) { "Unexpected exception loading registry" } + logger.warn{ "Using default registry" } + defaultRegistry + } + + /** + * Validates if the application ID and privacy key are valid. + */ + fun isValidApplication(applicationId: String, privacyKey: String): Boolean { + val application = appConfig.applications.find { it.id == applicationId } + return application != null && application.privacyKeys.contains(privacyKey) + } + + /** + * Gets an application by ID. + */ + fun getApplication(applicationId: String): ApplicationConfig? { + return appConfig.applications.find { it.id == applicationId } + } +} + +fun ConfigCollection.Companion.custom(config: AppConfig) = ConfigCollection(defaultConfig = config) diff --git a/src/main/kotlin/org/coralprotocol/coralserver/mcpresources/MessageResource.kt b/src/main/kotlin/org/coralprotocol/coralserver/mcpresources/MessageResource.kt new file mode 100644 index 0000000000000000000000000000000000000000..afd885095e6a5dd94ef3e641b7c4b11dc52ec082 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/mcpresources/MessageResource.kt @@ -0,0 +1,36 @@ +package org.coralprotocol.coralserver.mcpresources + +import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest +import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.TextResourceContents +import nl.adaptivity.xmlutil.QName +import nl.adaptivity.xmlutil.serialization.XML +import org.coralprotocol.coralserver.models.ResolvedThread +import org.coralprotocol.coralserver.models.resolve +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp + +private fun CoralAgentIndividualMcp.handler(request: ReadResourceRequest): ReadResourceResult { + val threadsAgentPrivyIn: List = this.coralAgentGraphSession.getAllThreadsAgentParticipatesIn(this.connectedAgentId).map { it -> it.resolve() } + val renderedThreads: String = XML.encodeToString(threadsAgentPrivyIn, rootName = QName("threads")) + return ReadResourceResult( + contents = listOf( + TextResourceContents( + text = renderedThreads, + uri = request.uri, + mimeType = "application/xml", + ) + ) + ) +} + +fun CoralAgentIndividualMcp.addMessageResource() { + addResource( + name = "message", + description = "Message resource", + uri = this.connectedUri, + mimeType = "application/json", + readHandler = { request: ReadResourceRequest -> + handler(request) + }, + ) +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/mcptools/AddParticipantTool.kt b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/AddParticipantTool.kt new file mode 100644 index 0000000000000000000000000000000000000000..51518f6556da753b6db3e01e94af7637c85412e8 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/AddParticipantTool.kt @@ -0,0 +1,71 @@ +package org.coralprotocol.coralserver.mcptools + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlinx.serialization.json.putJsonObject +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp + +private val logger = KotlinLogging.logger {} + +/** + * Extension function to add the add participant tool to a server. + */ +fun CoralAgentIndividualMcp.addAddParticipantTool() { + addTool( + name = "add_participant", + description = "Add a participant to a thread", + inputSchema = Tool.Input( + properties = buildJsonObject { + putJsonObject("threadId") { + put("type", "string") + put("description", "ID of the thread") + } + putJsonObject("participantId") { + put("type", "string") + put("description", "ID of the agent to add") + } + }, + required = listOf("threadId", "participantId") + ) + ) { request -> + handleAddParticipant(request) + } +} + +/** + * Handles the add participant tool request. + */ +private fun CoralAgentIndividualMcp.handleAddParticipant(request: CallToolRequest): CallToolResult { + try { + val json = Json { ignoreUnknownKeys = true } + val input = json.decodeFromString(request.arguments.toString()) + val success = coralAgentGraphSession.addParticipantToThread( + threadId = input.threadId, + participantId = input.participantId + ) + + if (success) { + return CallToolResult( + content = listOf(TextContent("Participant added successfully to thread ${input.threadId}")) + ) + } else { + val errorMessage = "Failed to add participant: Thread not found, participant not found, or thread is closed" + logger.error { errorMessage } + return CallToolResult( + content = listOf(TextContent(errorMessage)) + ) + } + } catch (e: Exception) { + val errorMessage = "Error adding participant: ${e.message}" + logger.error(e) { errorMessage } + return CallToolResult( + content = listOf(TextContent(errorMessage)) + ) + } +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/mcptools/CloseThreadTool.kt b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/CloseThreadTool.kt new file mode 100644 index 0000000000000000000000000000000000000000..ace00c716edfc940689e8d70238752ff1e8a8ac6 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/CloseThreadTool.kt @@ -0,0 +1,71 @@ +package org.coralprotocol.coralserver.mcptools + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlinx.serialization.json.putJsonObject +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp + +private val logger = KotlinLogging.logger {} + +/** + * Extension function to add the close thread tool to a server. + */ +fun CoralAgentIndividualMcp.addCloseThreadTool() { + addTool( + name = "close_thread", + description = "Close a thread with a summary", + inputSchema = Tool.Input( + properties = buildJsonObject { + putJsonObject("threadId") { + put("type", "string") + put("description", "ID of the thread to close") + } + putJsonObject("summary") { + put("type", "string") + put("description", "Summary of the thread") + } + }, + required = listOf("threadId", "summary") + ) + ) { request -> + handleCloseThread(request) + } +} + +/** + * Handles the close thread tool request. + */ +private fun CoralAgentIndividualMcp.handleCloseThread(request: CallToolRequest): CallToolResult { + try { + val json = Json { ignoreUnknownKeys = true } + val input = json.decodeFromString(request.arguments.toString()) + val success = coralAgentGraphSession.closeThread( + threadId = input.threadId, + summary = input.summary + ) + + if (success) { + return CallToolResult( + content = listOf(TextContent("Thread closed successfully with summary: ${input.summary}")) + ) + } else { + val errorMessage = "Failed to close thread: Thread not found" + logger.error { errorMessage } + return CallToolResult( + content = listOf(TextContent(errorMessage)) + ) + } + } catch (e: Exception) { + val errorMessage = "Error closing thread: ${e.message}" + logger.error(e) { errorMessage } + return CallToolResult( + content = listOf(TextContent(errorMessage)) + ) + } +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/mcptools/CreateThreadTool.kt b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/CreateThreadTool.kt new file mode 100644 index 0000000000000000000000000000000000000000..8582f6d7ad4cbe30f596243057ab39e06875b24a --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/CreateThreadTool.kt @@ -0,0 +1,78 @@ +package org.coralprotocol.coralserver.mcptools + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlinx.serialization.json.putJsonObject +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp + + +private val logger = KotlinLogging.logger {} + +/** + * Extension function to add the create thread tool to a server. + */ +fun CoralAgentIndividualMcp.addCreateThreadTool() { + addTool( + name = "create_thread", + description = "Create a new thread with a list of participants", + inputSchema = Tool.Input( + properties = buildJsonObject { + putJsonObject("threadName") { + put("type", "string") + put("description", "Name of the thread") + } + putJsonObject("participantIds") { + put("type", "array") + put("description", "List of agent IDs to include as participants") + putJsonObject("items") { + put("type", "string") + } + } + }, + required = listOf("threadName", "participantIds") + ) + ) { request -> + handleCreateThread(request) + } +} + +/** + * Handles the create thread tool request. + */ +private fun CoralAgentIndividualMcp.handleCreateThread(request: CallToolRequest): CallToolResult { + try { + val json = Json { ignoreUnknownKeys = true } + val input = json.decodeFromString(request.arguments.toString()) + val thread = coralAgentGraphSession.createThread( + name = input.threadName, + creatorId = connectedAgentId, + participantIds = input.participantIds + ) + + return CallToolResult( + content = listOf( + TextContent( + """ + |Thread created successfully: + |ID: ${thread.id} + |Name: ${thread.name} + |Creator: ${thread.creatorId} + |Participants: ${thread.participants.joinToString(", ")} + """.trimMargin() + ) + ) + ) + } catch (e: Exception) { + val errorMessage = "Error creating thread: ${e.message}" + logger.error(e) { errorMessage } + return CallToolResult( + content = listOf(TextContent(errorMessage)) + ) + } +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/mcptools/ListAgentsTool.kt b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/ListAgentsTool.kt new file mode 100644 index 0000000000000000000000000000000000000000..0f5d8a5c6089cae9f59f8c3c31b9d36c94bbd2bb --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/ListAgentsTool.kt @@ -0,0 +1,84 @@ +package org.coralprotocol.coralserver.mcptools + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlinx.serialization.json.putJsonObject +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp + +private val logger = KotlinLogging.logger {} + +/** + * Extension function to add the list agents tool to a server. + */ +fun CoralAgentIndividualMcp.addListAgentsTool() { + addTool( + name = "list_agents", + description = "List all registered agents in your contact.", + inputSchema = Tool.Input( + properties = buildJsonObject { + putJsonObject("includeDetails") { + put("type", "boolean") + put("description", "Whether to include agent details in the response") + } + }, + required = listOf("includeDetails") + ) + ) { request -> + handleListAgents(request) + } +} + +/** + * Handles the list agents tool request. + */ +private fun CoralAgentIndividualMcp.handleListAgents(request: CallToolRequest): CallToolResult { + try { + val json = Json { ignoreUnknownKeys = true } + val input = json.decodeFromString(request.arguments.toString()) + val agents = coralAgentGraphSession.getAllAgents() + + if (agents.isNotEmpty()) { + val agentsList = if (input.includeDetails) { + agents.joinToString("\n") { agent -> + val description = if (agent.description.isNotEmpty()) { + ", Description: ${agent.description}" + } else { + "" + } + "ID: ${agent.id}, $description" + } + } else { + agents.joinToString(", ") { agent -> agent.id } + } + + return CallToolResult( + content = listOf( + TextContent( + """ + Registered Agents (${agents.size}): + $agentsList + """.trimIndent() + ) + ) + ) + } else { + return CallToolResult( + content = listOf(TextContent("No agents are currently registered in the system")) + ) + } + } catch (e: Exception) { + val errorMessage = "Error listing agents: ${e.message}" + logger.error(e) { errorMessage } + + // Return a user-friendly error message + return CallToolResult( + content = listOf(TextContent(errorMessage)) + ) + } +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/mcptools/RemoveParticipantTool.kt b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/RemoveParticipantTool.kt new file mode 100644 index 0000000000000000000000000000000000000000..7c73056a5a146e16fffdc1e8ca1f1d60fb40bb5b --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/RemoveParticipantTool.kt @@ -0,0 +1,71 @@ +package org.coralprotocol.coralserver.mcptools + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlinx.serialization.json.putJsonObject +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp + +private val logger = KotlinLogging.logger {} + +/** + * Extension function to add the remove participant tool to a server. + */ +fun CoralAgentIndividualMcp.addRemoveParticipantTool() { + addTool( + name = "remove_participant", + description = "Remove a participant from a thread", + inputSchema = Tool.Input( + properties = buildJsonObject { + putJsonObject("threadId") { + put("type", "string") + put("description", "ID of the thread") + } + putJsonObject("participantId") { + put("type", "string") + put("description", "ID of the agent to remove") + } + }, + required = listOf("threadId", "participantId") + ) + ) { request -> + handleRemoveParticipant(request) + } +} + +/** + * Handles the remove participant tool request. + */ +private fun CoralAgentIndividualMcp.handleRemoveParticipant(request: CallToolRequest): CallToolResult { + try { + val json = Json { ignoreUnknownKeys = true } + val input = json.decodeFromString(request.arguments.toString()) + val success = coralAgentGraphSession.removeParticipantFromThread( + threadId = input.threadId, + participantId = input.participantId + ) + + if (success) { + return CallToolResult( + content = listOf(TextContent("Participant removed successfully from thread ${input.threadId}")) + ) + } else { + val errorMessage = "Failed to remove participant: Thread not found, participant not found, or thread is closed" + logger.error { errorMessage } + return CallToolResult( + content = listOf(TextContent(errorMessage)) + ) + } + } catch (e: Exception) { + val errorMessage = "Error removing participant: ${e.message}" + logger.error(e) { errorMessage } + return CallToolResult( + content = listOf(TextContent(errorMessage)) + ) + } +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/mcptools/SendMessageTool.kt b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/SendMessageTool.kt new file mode 100644 index 0000000000000000000000000000000000000000..eed1cdc6650a358ca8d4cd04f98581725a73d8ad --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/SendMessageTool.kt @@ -0,0 +1,91 @@ +package org.coralprotocol.coralserver.mcptools + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlinx.serialization.json.putJsonObject +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp + +private val logger = KotlinLogging.logger {} + +/** + * Extension function to add the send message tool to a server. + */ +fun CoralAgentIndividualMcp.addSendMessageTool() { + addTool( + name = "send_message", + description = "Send a message to a thread", + inputSchema = Tool.Input( + properties = buildJsonObject { + putJsonObject("threadId") { + put("type", "string") + put("description", "ID of the thread") + } + putJsonObject("content") { + put("type", "string") + put("description", "Content of the message") + } + putJsonObject("mentions") { + put("type", "array") + put("description", "List of agent IDs to mention in the message. You *must* mention an agent for them to be made aware of the message.") + putJsonObject("items") { + put("type", "string") + } + } + }, + required = listOf("threadId", "content", "mentions") + ) + ) { request -> + handleSendMessage(request) + } +} + +/** + * Handles the send message tool request. + */ +private suspend fun CoralAgentIndividualMcp.handleSendMessage(request: CallToolRequest): CallToolResult { + try { + val json = Json { ignoreUnknownKeys = true } + val input = json.decodeFromString(request.arguments.toString()) + val message = coralAgentGraphSession.sendMessage( + threadId = input.threadId, + senderId = this.connectedAgentId, + content = input.content, + mentions = input.mentions + ) + + if (message != null) { + logger.info { message } + + return CallToolResult( + content = listOf( + TextContent( + """ + Message sent successfully: + ID: ${message.id} + Thread: ${message.thread.id} + Sender: ${message.sender.id} + """.trimIndent() + ) + ) + ) + } else { + val errorMessage = "Failed to send message: Thread not found, sender not found, thread is closed, or sender is not a participant" + logger.error { errorMessage } + return CallToolResult( + content = listOf(TextContent(errorMessage)) + ) + } + } catch (e: Exception) { + val errorMessage = "Error sending message: ${e.message}" + logger.error(e) { errorMessage } + return CallToolResult( + content = listOf(TextContent(errorMessage)) + ) + } +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/mcptools/ThreadInputModels.kt b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/ThreadInputModels.kt new file mode 100644 index 0000000000000000000000000000000000000000..7badb32b55ce859c17963885d92493dbfb30cc3f --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/ThreadInputModels.kt @@ -0,0 +1,65 @@ +package org.coralprotocol.coralserver.mcptools + +import kotlinx.serialization.Serializable + +/** + * Tool for creating a new thread. + */ +@Serializable +data class CreateThreadInput( + val threadName: String, + val participantIds: List +) + +/** + * Tool for adding a participant to a thread. + */ +@Serializable +data class AddParticipantInput( + val threadId: String, + val participantId: String +) + +/** + * Tool for removing a participant from a thread. + */ +@Serializable +data class RemoveParticipantInput( + val threadId: String, + val participantId: String +) + +/** + * Tool for closing a thread with a summary. + */ +@Serializable +data class CloseThreadInput( + val threadId: String, + val summary: String +) + +/** + * Tool for sending a message to a thread. + */ +@Serializable +data class SendMessageInput( + val threadId: String, + val content: String, + val mentions: List = emptyList() +) + +/** + * Tool for waiting for new messages mentioning an agent. + */ +@Serializable +data class WaitForMentionsInput( + val timeoutMs: Long = 30000 +) + +/** + * Tool for listing all registered agents. + */ +@Serializable +data class ListAgentsInput( + val includeDetails: Boolean = true // Whether to include agent details in the response +) diff --git a/src/main/kotlin/org/coralprotocol/coralserver/mcptools/ThreadToolsRegistry.kt b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/ThreadToolsRegistry.kt new file mode 100644 index 0000000000000000000000000000000000000000..add3b87f614a193676a336026f1d4ab62a1cbab5 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/ThreadToolsRegistry.kt @@ -0,0 +1,16 @@ +package org.coralprotocol.coralserver.mcptools + +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp + +/** + * Extension function to add all thread-based tools to a server. + */ +fun CoralAgentIndividualMcp.addThreadTools() { + addListAgentsTool() + addCreateThreadTool() + addAddParticipantTool() + addRemoveParticipantTool() + addCloseThreadTool() + addSendMessageTool() + addWaitForMentionsTool() +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/mcptools/WaitForMentionsTool.kt b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/WaitForMentionsTool.kt new file mode 100644 index 0000000000000000000000000000000000000000..6f59613e8442649cec5d504174f483e64b303f31 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/mcptools/WaitForMentionsTool.kt @@ -0,0 +1,86 @@ +package org.coralprotocol.coralserver.mcptools + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlinx.serialization.json.putJsonObject +import nl.adaptivity.xmlutil.serialization.XML +import org.coralprotocol.coralserver.models.AgentState +import org.coralprotocol.coralserver.models.resolve +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp + +private val logger = KotlinLogging.logger {} + +/** + * Extension function to add the wait for mentions tool to a server. + */ +fun CoralAgentIndividualMcp.addWaitForMentionsTool() { + addTool( + name = "wait_for_mentions", + description = "Wait until mentioned. Call this tool when you're done or want to wait for another agent to respond. This will block until a message is received. You will see all unread messages.", + inputSchema = Tool.Input( + properties = buildJsonObject { + putJsonObject("timeoutMs") { + put("type", "number") + put("description", "Timeout in milliseconds (default: $maxWaitForMentionsTimeoutMs ms). Must be between 0 and $maxWaitForMentionsTimeoutMs ms.") + } + }, + required = listOf("timeoutMs") + ) + ) { request: CallToolRequest -> + handleWaitForMentions(request) + } +} + +/** + * Handles the wait for mentions tool request. + */ +private suspend fun CoralAgentIndividualMcp.handleWaitForMentions(request: CallToolRequest): CallToolResult { + try { + val json = Json { ignoreUnknownKeys = true } + val input = json.decodeFromString(request.arguments.toString()) + logger.info { "Waiting for mentions for agent $connectedAgentId with timeout ${input.timeoutMs}ms" } + if(input.timeoutMs < 0) { + return CallToolResult( + content = listOf(TextContent("Timeout must be greater than 0")) + ) + } + if(input.timeoutMs > maxWaitForMentionsTimeoutMs) { + return CallToolResult( + content = listOf(TextContent("Timeout must not exceed the maximum of $maxWaitForMentionsTimeoutMs ms")) + ) + } + + coralAgentGraphSession.setAgentState(agentId = connectedAgentId, state = AgentState.Listening) + // Use the session to wait for mentions + val messages = coralAgentGraphSession.waitForMentions( + agentId = connectedAgentId, + timeoutMs = input.timeoutMs + ) + + coralAgentGraphSession.setAgentState(agentId = connectedAgentId, state = AgentState.Busy) + if (messages.isNotEmpty()) { + logger.info { "Received ${messages.size} messages for agent $connectedAgentId" } + val formattedMessages = XML.encodeToString (messages.map { message -> message.resolve() }) + return CallToolResult( + content = listOf(TextContent(formattedMessages)) + ) + } else { + return CallToolResult( + content = listOf(TextContent("No new messages received within the timeout period")) + ) + } + } catch (e: Exception) { + val errorMessage = "Error waiting for mentions: ${e.message}" + logger.error(e) { errorMessage } + coralAgentGraphSession.setAgentState(agentId = connectedAgentId, state = AgentState.Busy) + return CallToolResult( + content = listOf(TextContent(errorMessage)) + ) + } +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/Agent.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/Agent.kt new file mode 100644 index 0000000000000000000000000000000000000000..5a83b92053b0ed740940bfb81066ad8b74d9f6d9 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/Agent.kt @@ -0,0 +1,38 @@ +package org.coralprotocol.coralserver.models + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.coralprotocol.coralserver.session.CustomTool + +/** + * Represents an agent in the system. + */ +// TODO: make Agent a data class, when URI's are implemented +@Serializable +class Agent( + val id: String, + var description: String = "", // Description of the agent's responsibilities + + var state: AgentState = AgentState.Disconnected, + var mcpUrl: String?, + + val extraTools: Set = setOf() +) + +@Serializable +enum class AgentState { + @SerialName("disconnected") + Disconnected, + @SerialName("connecting") + Connecting, + @SerialName("listening") + Listening, + @SerialName("busy") + Busy, + @SerialName("dead") + Dead, +} + +public fun AgentState.isConnected(): Boolean { + return this == AgentState.Listening || this == AgentState.Busy +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/Message.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/Message.kt new file mode 100644 index 0000000000000000000000000000000000000000..b428bd223e8db67cc92b482df4a9826a51f130aa --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/Message.kt @@ -0,0 +1,45 @@ +package org.coralprotocol.coralserver.models + +import java.util.* + +/** + * Represents a message in a thread. + */ +class Message private constructor ( + val id: String = UUID.randomUUID().toString(), + val thread: Thread, + val sender: Agent, + val content: String, + val timestamp: Long = System.currentTimeMillis(), + val mentions: List = emptyList(), + var telemetry: Telemetry? +) { + companion object { + fun create(thread: Thread, sender: Agent, content: String, mentions: List = emptyList()): Message { + if (thread.isClosed) throw IllegalArgumentException("Thread $thread is closed") + + if (!thread.participants.contains(sender.id)) { + throw IllegalArgumentException("Sender agent not a member of thread $thread") + } + + val validMentions = mentions.filter { thread.participants.contains(it) } + return Message ( + thread = thread, + sender = sender, + content = content, + mentions = validMentions, + telemetry = null + ) + } + } +} + +fun Message.resolve(): ResolvedMessage = ResolvedMessage( + id = id, + threadName = this.thread.name, + threadId = this.thread.id, + senderId = this.sender.id, + content = content, + timestamp = timestamp, + mentions = mentions +) \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/ResolvedMessage.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/ResolvedMessage.kt new file mode 100644 index 0000000000000000000000000000000000000000..b8ac3e08e137a4ae1d53489efda3a02eb8f62ef6 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/ResolvedMessage.kt @@ -0,0 +1,17 @@ +package org.coralprotocol.coralserver.models + +import kotlinx.serialization.Serializable + +/** + * Represents a message in a thread. + */ +@Serializable +data class ResolvedMessage( + val id: String, + val threadName: String, + val threadId: String, + val senderId: String, + val content: String, + val timestamp: Long, + val mentions: List +) \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/SocketEvent.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/SocketEvent.kt new file mode 100644 index 0000000000000000000000000000000000000000..60351bf8e2c4e5eace8946f9666bafb520355e8d --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/SocketEvent.kt @@ -0,0 +1,29 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.models + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator +import org.coralprotocol.coralserver.session.SessionEvent as SessionEvent + +@Serializable +@JsonClassDiscriminator("type") +sealed interface SocketEvent { + @Serializable + @SerialName("debug_agent_registered") + data class DebugAgentRegistered(val id: String) : SocketEvent + + @Serializable + @SerialName("thread_list") + data class ThreadList(val threads: List) : SocketEvent + + @Serializable + @SerialName("agent_list") + data class AgentList(val agents: List) : SocketEvent + + @Serializable + @SerialName("session") + data class Session(val event: SessionEvent) : SocketEvent +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/Telemetry.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/Telemetry.kt new file mode 100644 index 0000000000000000000000000000000000000000..bb8c307517a8a879efcc8870512f5ec6ba1caec2 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/Telemetry.kt @@ -0,0 +1,59 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.models + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator +import kotlinx.serialization.json.JsonIgnoreUnknownKeys +import kotlinx.serialization.json.JsonObject +import org.coralprotocol.coralserver.models.telemetry.generic.Message as GenericMessage +import org.coralprotocol.coralserver.models.telemetry.openai.Message as OpenAIMessage + +@Serializable +@JsonClassDiscriminator("format") +sealed class TelemetryMessages() { + @Suppress("unused") + @Serializable + @SerialName("OpenAI") + data class OpenAI(val data: List) : TelemetryMessages() + + @Suppress("unused") + @Serializable + @SerialName("Generic") + data class Generic(val data: List) : TelemetryMessages() +} + +@Serializable +data class TelemetryTarget( + val threadId: String, + val messageId: String +) + +@Serializable +@JsonIgnoreUnknownKeys +data class Document( + val id: String, + val text: String, + + // This can contain user-defined fields and values +) + +@Serializable +data class Telemetry( + val modelDescription: String, + val preamble: String? = null, + val resources: List, + val tools: List, + val temperature: Double? = null, + val maxTokens: Long? = null, + val additionalParams: JsonObject? = null, + val messages: TelemetryMessages +) + +@Serializable +data class TelemetryPost( + val targets: List, + val data: Telemetry +) \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/Thread.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/Thread.kt new file mode 100644 index 0000000000000000000000000000000000000000..f6317b33f91a42c3b6f1b6b5ac5291077c7f66b5 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/Thread.kt @@ -0,0 +1,38 @@ +package org.coralprotocol.coralserver.models + +import kotlinx.serialization.Serializable +import java.util.* + +/** + * Represents a thread with participants. + */ +data class Thread( + val id: String = UUID.randomUUID().toString(), + val name: String, + val creatorId: String, + val participants: MutableList = mutableListOf(), + val messages: MutableList = mutableListOf(), + var isClosed: Boolean = false, + var summary: String? = null +) + +@Serializable +data class ResolvedThread( + val id: String = UUID.randomUUID().toString(), + val name: String, + val creatorId: String, + val participants: List = listOf(), + val messages: List = listOf(), + var isClosed: Boolean = false, + var summary: String? = null +) + +fun Thread.resolve(): ResolvedThread = ResolvedThread( + id = id, + name = name, + creatorId = creatorId, + participants = participants, + messages = messages.map { it.resolve() }, + isClosed = isClosed, + summary = summary +) \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/generic/Content.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/generic/Content.kt new file mode 100644 index 0000000000000000000000000000000000000000..d4ea85654cbe3d4e3986c9b2eabcacc7aee4d4e7 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/generic/Content.kt @@ -0,0 +1,232 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.models.telemetry.generic + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator +import kotlinx.serialization.json.JsonIgnoreUnknownKeys + +@Serializable +enum class ContentFormat { + @SerialName("base64") + @Suppress("unused") + BASE64, + + @SerialName("string") + @Suppress("unused") + STRING +} + +@Serializable +enum class ImageDetail { + @SerialName("low") + @Suppress("unused") + LOW, + + @SerialName("high") + @Suppress("unused") + HIGH, + + @SerialName("auto") + @Suppress("unused") + AUTO +} + +@Serializable +enum class ImageMediaType { + @SerialName("jpeg") + @Suppress("unused") + JPEG, + + @SerialName("png") + @Suppress("unused") + PNG, + + @SerialName("gif") + @Suppress("unused") + GIF, + + @SerialName("webp") + @Suppress("unused") + WEBP, + + @SerialName("heic") + @Suppress("unused") + HEIC, + + @SerialName("heif") + @Suppress("unused") + HEIF, + + @SerialName("svg") + @Suppress("unused") + SVG, +} + +@Serializable +enum class DocumentMediaType { + @SerialName("pdf") + @Suppress("unused") + PDF, + + @SerialName("txt") + @Suppress("unused") + TXT, + + @SerialName("rtf") + @Suppress("unused") + RTF, + + @SerialName("html") + @Suppress("unused") + HTML, + + @SerialName("css") + @Suppress("unused") + CSS, + + @SerialName("markdown") + @Suppress("unused") + MARKDOWN, + + @SerialName("csv") + @Suppress("unused") + CSV, + + @SerialName("xml") + @Suppress("unused") + XML, + + @SerialName("javascript") + @Suppress("unused") + JAVASCRIPT, + + @SerialName("python") + @Suppress("unused") + PYTHON, +} + +@Serializable +enum class AudioMediaType { + @SerialName("wav") + @Suppress("unused") + WAV, + + @SerialName("mp3") + @Suppress("unused") + MP3, + + @SerialName("aiff") + @Suppress("unused") + AIFF, + + @SerialName("aac") + @Suppress("unused") + AAC, + + @SerialName("ogg") + @Suppress("unused") + OGG, + + @SerialName("flac") + @Suppress("unused") + FLAC, +} + +@Serializable +enum class VideoMediaType { + @SerialName("avi") + @Suppress("unused") + AVI, + + @SerialName("mp4") + @Suppress("unused") + MP4, + + @SerialName("mpeg") + @Suppress("unused") + MPEG, +} + +@Serializable +@JsonClassDiscriminator("type") +@SerialName("GenericUserContent") +sealed class UserContent { + @Serializable + @SerialName("text") + @Suppress("unused") + data class Text(val text: String): UserContent() + + @Serializable + @SerialName("tool_result") + @Suppress("unused") + data class ToolResult( + val id: String, + val callId: String? = null, + val content: List + ): UserContent() + + @Serializable + @SerialName("image") + @Suppress("unused") + data class Image( + val data: String, + val format: ContentFormat? = null, + val mediaType: ImageMediaType? = null, + val detail: ImageDetail? = null + ): UserContent() + + @Serializable + @SerialName("audio") + @Suppress("unused") + data class Audio( + val data: String, + val format: ContentFormat? = null, + val mediaType: AudioMediaType? = null, + ): UserContent() + + @Serializable + @SerialName("video") + @Suppress("unused") + @JsonIgnoreUnknownKeys + data class Video( + val data: String, + val format: ContentFormat? = null, + val mediaType: VideoMediaType? = null, + ): UserContent() + + @Serializable + @SerialName("document") + @Suppress("unused") + data class Document( + val data: String, + val format: ContentFormat? = null, + val mediaType: DocumentMediaType? = null, + ): UserContent() +} + +@Serializable +@JsonClassDiscriminator("type") +@SerialName("GenericAssistantContent") +sealed class AssistantContent { + @Serializable + @SerialName("assistant_text") + @Suppress("unused") + data class Text(val text: String) : AssistantContent() + + @Serializable + @SerialName("assistant_tool_call") + @Suppress("unused") + data class ToolCall( + val id: String, + val callId: String? = null, + val function: ToolFunction + ) : AssistantContent() + + @Serializable + @SerialName("assistant_reasoning") + @Suppress("unused") + data class Reasoning(val reasoning: List) : AssistantContent() +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/generic/Message.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/generic/Message.kt new file mode 100644 index 0000000000000000000000000000000000000000..854c8e774597e1bd0d2dde8c4f09a4b3b0e21ea7 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/generic/Message.kt @@ -0,0 +1,27 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.models.telemetry.generic + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator + +@Serializable +@JsonClassDiscriminator("role") +@SerialName("GenericMessage") +sealed class Message() { + + @Serializable + @SerialName("user") + @Suppress("unused") + data class UserMessage(val content: List) : Message() + + @Serializable + @SerialName("assistant") + @Suppress("unused") + data class AssistantMessage( + val id: String? = null, + val content: List, + ) : Message() +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/generic/Tool.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/generic/Tool.kt new file mode 100644 index 0000000000000000000000000000000000000000..3be353185f2d00a08d419dc4f501961404cb27dd --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/generic/Tool.kt @@ -0,0 +1,31 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.models.telemetry.generic + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator + +@Serializable +@JsonClassDiscriminator("type") +@SerialName("GenericToolResultContent") +sealed class ToolResultContent { + @Serializable + @SerialName("tool_text") + @Suppress("unused") + data class Text(val text: String): ToolResultContent() + + @Serializable + @SerialName("tool_image") + @Suppress("unused") + data class Image( + val data: String, + val format: ContentFormat? = null, + val mediaType: ImageMediaType? = null, + val detail: ImageDetail? = null + ): ToolResultContent() +} + +@Serializable +data class ToolFunction(val name: String, val arguments: String) \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/openai/Content.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/openai/Content.kt new file mode 100644 index 0000000000000000000000000000000000000000..d084d7893033d88105749ba2039260d56f0ed678 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/openai/Content.kt @@ -0,0 +1,65 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.models.telemetry.openai + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator +import org.coralprotocol.coralserver.models.telemetry.generic.AudioMediaType +import org.coralprotocol.coralserver.models.telemetry.generic.ImageDetail + +@Serializable +enum class SystemContentType { + @SerialName("text") + @Suppress("unused") + TEXT, +} + +@Serializable +data class ImageUrl(val url: String, val detail: ImageDetail) + +@Serializable +data class InputAudio(val data: String, val format: AudioMediaType) + +@Serializable +@SerialName("OpenAISystemContent") +data class SystemContent(val type: SystemContentType, val text: String) + +@Serializable +data class AudioAssistant(val id: String); + +@Serializable +@JsonClassDiscriminator("type") +@SerialName("OpenAIUserContent") +sealed class UserContent { + @Serializable + @SerialName("text") + @Suppress("unused") + data class Text(val text: String): UserContent() + + @Serializable + @SerialName("image_url") + @Suppress("unused") + data class Image(val imageUrl: ImageUrl): UserContent() + + @Serializable + @SerialName("audio") + @Suppress("unused") + data class Audio(val inputAudio: InputAudio): UserContent() +} + +@Serializable +@JsonClassDiscriminator("type") +@SerialName("OpenAIAssistantContent") +sealed class AssistantContent { + @Serializable + @SerialName("text") + @Suppress("unused") + data class Text(val text: String) : AssistantContent() + + @Serializable + @SerialName("refusal") + @Suppress("unused") + data class Refusal(val refusal: String) : AssistantContent() +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/openai/Message.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/openai/Message.kt new file mode 100644 index 0000000000000000000000000000000000000000..43a343de2a30970097d844108e74c92a8d9417af --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/openai/Message.kt @@ -0,0 +1,42 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.models.telemetry.openai + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator + +@Serializable +@JsonClassDiscriminator("role") +@SerialName("OpenAIMessage") +sealed class Message { + @Serializable + @SerialName("developer") + @Suppress("unused") + data class SystemMessage(val content: List, val name: String?) : Message() + + @Serializable + @SerialName("user") + @Suppress("unused") + data class UserMessage(val content: List, val name: String? = null) : Message() + + @Serializable + @SerialName("assistant") + @Suppress("unused") + data class AssistantMessage( + val content: List, + val refusal: String? = null, + val audio: AudioAssistant? = null, + val name: String? = null, + val toolCalls: List + ) : Message() + + @Serializable + @SerialName("tool") + @Suppress("unused") + data class ToolMessage( + val toolCallId: String, + val content: List + ) : Message() +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/openai/Tool.kt b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/openai/Tool.kt new file mode 100644 index 0000000000000000000000000000000000000000..bac64bfaeba705b44ddbbd2fcb4a46158782aabb --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/models/telemetry/openai/Tool.kt @@ -0,0 +1,28 @@ +package org.coralprotocol.coralserver.models.telemetry.openai + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +enum class ToolType { + @SerialName("function") + @Suppress("unused") + Function +} + +@Serializable +data class Function(val name: String, val arguments: String) + +@Serializable +data class ToolCall(val id: String, val type: ToolType, val function: Function) + +@Serializable +enum class ToolResultContentType { + @SerialName("text") + @Suppress("unused") + Text +} + +@Serializable +@SerialName("OpenAIToolResultContent") +data class ToolResultContent(val type: ToolResultContentType, val text: String) \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/AgentRoutes.kt b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/AgentRoutes.kt new file mode 100644 index 0000000000000000000000000000000000000000..d7689a2aef6460ee58b6174bb6283e0676574001 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/AgentRoutes.kt @@ -0,0 +1,57 @@ +package org.coralprotocol.coralserver.routes.api.v1 + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.github.smiley4.ktoropenapi.resources.get +import io.ktor.http.HttpStatusCode +import io.ktor.resources.Resource +import io.ktor.server.response.respond +import io.ktor.server.routing.* +import org.coralprotocol.coralserver.config.ConfigCollection +import org.coralprotocol.coralserver.agent.registry.AgentExport +import org.coralprotocol.coralserver.agent.registry.PublicRegistryAgent +import org.coralprotocol.coralserver.agent.registry.toPublic +import org.coralprotocol.coralserver.session.SessionManager + +private val logger = KotlinLogging.logger {} + +@Resource("/api/v1/agents") +class Agents + +@Resource("/api/v1/agents/exported") +class ExportedAgents + +fun Routing.agentApiRoutes(appConfig: ConfigCollection, sessionManager: SessionManager) { + get({ + summary = "Get available agents" + description = "Fetches a list of all agents available to the Coral server" + operationId = "getAvailableAgents" + response { + HttpStatusCode.OK to { + description = "Success" + body> { + description = "List of available agents" + } + } + } + }) { + val agents = appConfig.registry.importedAgents.map { entry -> entry.value.toPublic(entry.key) } + call.respond(HttpStatusCode.OK, agents) + } + + get({ + summary = "Gets exported agents" + description = "Fetches agents the Coral server has exported to other servers" + operationId = "getExportedAgents" + response { + HttpStatusCode.OK to { + description = "Success" + body> { + description = "List of exported agents" + } + } + } + }) { + val agents = appConfig.registry.exportedAgents.map { entry -> entry.value.toPublic(entry.key) } + call.respond(HttpStatusCode.OK, agents) + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/DebugRoutes.kt b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/DebugRoutes.kt new file mode 100644 index 0000000000000000000000000000000000000000..6c00b4156ef4c718573f1d07f30e1ee79d57b17e --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/DebugRoutes.kt @@ -0,0 +1,139 @@ +package org.coralprotocol.coralserver.routes.api.v1 + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.github.smiley4.ktoropenapi.resources.post +import io.ktor.http.* +import io.ktor.resources.* +import io.ktor.server.request.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import org.coralprotocol.coralserver.mcptools.CreateThreadInput +import org.coralprotocol.coralserver.mcptools.SendMessageInput +import org.coralprotocol.coralserver.models.resolve +import org.coralprotocol.coralserver.server.RouteException +import org.coralprotocol.coralserver.session.SessionManager + +private val logger = KotlinLogging.logger {} + +@Resource("/api/v1/debug/thread/{applicationId}/{privacyKey}/{coralSessionId}/{debugAgentId}") +class DebugCreateThread( + val applicationId: String, + val privacyKey: String, + val coralSessionId: String, + val debugAgentId: String +) + +@Resource("/api/v1/debug/thread/sendMessage/{applicationId}/{privacyKey}/{coralSessionId}/{debugAgentId}") +class DebugSendMessage( + val applicationId: String, + val privacyKey: String, + val coralSessionId: String, + val debugAgentId: String +) + +fun Routing.debugApiRoutes(sessionManager: SessionManager) { + post({ + summary = "Create thread" + description = "Creates a new thread" + operationId = "debugCreateThread" + request { + pathParameter("applicationId") { + description = "The application ID" + } + pathParameter("privacyKey") { + description = "The privacy key" + } + pathParameter("coralSessionId") { + description = "The Coral session ID" + } + pathParameter("debugAgentId") { + description = "The debug agent ID" + } + body { + description = "Thread creation request" + } + } + response { + HttpStatusCode.OK to { + description = "Thread created successfully" + } + HttpStatusCode.NotFound to { + description = "Session not found" + } + HttpStatusCode.InternalServerError to { + description = "Error creating thread" + } + } + }) { debugRequest -> + // TODO (alan): proper appId/privacyKey based lookups when session manager is updated + val session = sessionManager.getSession(debugRequest.coralSessionId) + ?: throw RouteException(HttpStatusCode.NotFound, "Session not found") + + try { + val request = call.receive() + val thread = session.createThread( + name = request.threadName, + creatorId = debugRequest.debugAgentId, + participantIds = request.participantIds + ) + + call.respond(thread.resolve()) + } catch (e: Exception) { + logger.error(e) { "Error while creating thread" } + call.respond(HttpStatusCode.InternalServerError, "Error: ${e.message}") + } + } + + post({ + summary = "Send message" + description = "Sends a message in debug mode" + operationId = "debugSendMessage" + request { + pathParameter("applicationId") { + description = "The application ID" + } + pathParameter("privacyKey") { + description = "The privacy key" + } + pathParameter("coralSessionId") { + description = "The Coral session ID" + } + pathParameter("debugAgentId") { + description = "The debug agent ID" + } + body { + description = "The message to send" + } + } + response { + HttpStatusCode.OK to { + description = "Message sent successfully" + } + HttpStatusCode.NotFound to { + description = "Session not found" + } + HttpStatusCode.InternalServerError to { + description = "Error sending message" + } + } + }) { debugRequest -> + // TODO (alan): proper appId/privacyKey based lookups when session manager is updated + val session = sessionManager.getSession(debugRequest.coralSessionId) + ?: throw RouteException(HttpStatusCode.NotFound, "Session not found") + + try { + val request = call.receive() + val message = session.sendMessage( + threadId = request.threadId, + senderId = debugRequest.debugAgentId, + content = request.content, + mentions = request.mentions + ) + + call.respond(message.resolve()) + } catch (e: Exception) { + logger.error(e) { "Error while sending message" } + call.respond(HttpStatusCode.InternalServerError, "Error: ${e.message}") + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/DocumentationRoutes.kt b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/DocumentationRoutes.kt new file mode 100644 index 0000000000000000000000000000000000000000..7318c51f359bfcf7e264f13e450fc129c8724d0c --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/DocumentationRoutes.kt @@ -0,0 +1,47 @@ +package org.coralprotocol.coralserver.routes.api.v1 + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.github.smiley4.ktoropenapi.resources.get +import io.ktor.http.* +import io.ktor.resources.* +import io.ktor.server.html.* +import io.ktor.server.routing.* +import kotlinx.html.* + +private val logger = KotlinLogging.logger {} + +@Resource("/v1/docs") +class Documentation + +fun Routing.documentationApiRoutes() { + get({ + hidden = true + }) { + call.respondHtml(HttpStatusCode.OK) { + head { + title("Scalar API Reference") + meta(charset = "utf-8") + meta(name = "viewport", content = "width=device-width, initial-scale=1") + } + body { + div { + id = "app" + } + + // Load the Script + script(src = "https://cdn.jsdelivr.net/npm/@scalar/api-reference") {} + + // Initialize the Scalar API Reference + script { + unsafe { + raw(""" + Scalar.createApiReference('#app', { + url: '/api_v1.json', + }) + """.trimIndent()) + } + } + } + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/MessageRoutes.kt b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/MessageRoutes.kt new file mode 100644 index 0000000000000000000000000000000000000000..4fd463ff78d814c6a68cec5e42abc2b31a48a89b --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/MessageRoutes.kt @@ -0,0 +1,156 @@ +package org.coralprotocol.coralserver.routes.api.v1 + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.github.smiley4.ktoropenapi.resources.post +import io.ktor.http.* +import io.ktor.resources.* +import io.ktor.server.response.* +import io.ktor.server.routing.Routing +import io.ktor.util.collections.* +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.SseServerTransport +import org.coralprotocol.coralserver.server.RouteException +import org.coralprotocol.coralserver.session.SessionManager +import kotlin.NoSuchElementException +import kotlin.String + +private val logger = KotlinLogging.logger {} + +@Resource("/api/v1/message/{applicationId}/{privacyKey}/{coralSessionId}") +class Message(val applicationId: String, val privacyKey: String, val coralSessionId: String) + +@Resource("/api/v1/message/devmode/{applicationId}/{privacyKey}/{coralSessionId}") +class DevModeMessage(val applicationId: String, val privacyKey: String, val coralSessionId: String) + +/** + * Configures message-related routes. + * + * @param servers A concurrent map to store server instances by transport session ID + */ +fun Routing.messageApiRoutes(servers: ConcurrentMap, sessionManager: SessionManager) { + // Message endpoint with application, privacy key, and session parameters + post({ + summary = "Send message" + description = "Sends a message" + operationId = "sendMessage" + request { + pathParameter("applicationId") { + description = "The application ID" + } + pathParameter("privacyKey") { + description = "The privacy key" + } + pathParameter("coralSessionId") { + description = "The Coral session ID" + } + } + response { + HttpStatusCode.OK to { + description = "Success" + } + HttpStatusCode.Forbidden to { + description = "Invalid application ID or privacy key" + } + HttpStatusCode.NotFound to { + description = "Invalid Coral session ID" + } + HttpStatusCode.BadRequest to { + description = "Invalid session ID" + } + HttpStatusCode.InternalServerError to { + description = "MCP error" + } + } + }) { message -> + logger.debug { "Received Message" } + + val session = sessionManager.getSession(message.coralSessionId) + if (session == null) { + call.respond(HttpStatusCode.NotFound, "Session not found") + return@post + } + + // Validate that the application and privacy key match the session + if (session.applicationId != message.applicationId || session.privacyKey != message.privacyKey) { + call.respond(HttpStatusCode.Forbidden, "Invalid application ID or privacy key for this session") + return@post + } + + // Get the transport + val transportId = call.request.queryParameters["sessionId"] + ?: throw RouteException(HttpStatusCode.BadRequest, "sessionId missing") + + val transport = servers[transportId]?.transport as? SseServerTransport + if (transport == null) { + call.respond(HttpStatusCode.BadRequest, "Transport not found") + return@post + } + + // Handle the message + try { + transport.handlePostMessage(call) + } catch (e: NoSuchElementException) { + logger.error(e) { "This error likely comes from an inspector or non-essential client and can probably be ignored. See https://github.com/modelcontextprotocol/kotlin-sdk/issues/7" } + call.respond(HttpStatusCode.InternalServerError, "Error handling message: ${e.message}") + } + } + + // DevMode message endpoint - no validation + post({ + summary = "Send development message" + description = "Sends a dev-mode message" + operationId = "sendDevMessage" + request { + pathParameter("applicationId") { + description = "The application ID" + } + pathParameter("privacyKey") { + description = "The privacy key" + } + pathParameter("coralSessionId") { + description = "The Coral session ID" + } + } + response { + HttpStatusCode.OK to { + description = "Success" + } + HttpStatusCode.NotFound to { + description = "Invalid Coral session ID" + } + HttpStatusCode.BadRequest to { + description = "Invalid session ID" + } + HttpStatusCode.InternalServerError to { + description = "MCP error" + } + } + }) { message -> + logger.debug { "Received DevMode Message" } + + // Get the session. It should exist even in dev mode as it was created in the sse endpoint + val session = sessionManager.getSession(message.coralSessionId) + if (session == null) { + call.respond(HttpStatusCode.NotFound, "Session not found") + return@post + } + + // Get the transport + val transportId = call.request.queryParameters["sessionId"] + ?: throw RouteException(HttpStatusCode.BadRequest, "sessionId missing") + + val transport = servers[transportId]?.transport as? SseServerTransport + if (transport == null) { + call.respond(HttpStatusCode.BadRequest, "Transport not found") + return@post + } + + // Handle the message + try { + transport.handlePostMessage(call) + } catch (e: NoSuchElementException) { + logger.error(e) { "This error likely comes from an inspector or non-essential client and can probably be ignored. See https://github.com/modelcontextprotocol/kotlin-sdk/issues/7" } + call.respond(HttpStatusCode.InternalServerError, "Error handling message: ${e.message}") + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/SessionRoutes.kt b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/SessionRoutes.kt new file mode 100644 index 0000000000000000000000000000000000000000..5e8d1f888169df1a7c35d640a8b1fb619bdaa9f3 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/SessionRoutes.kt @@ -0,0 +1,180 @@ +package org.coralprotocol.coralserver.routes.api.v1 + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.github.smiley4.ktoropenapi.resources.get +import io.github.smiley4.ktoropenapi.resources.post +import io.ktor.http.* +import io.ktor.resources.* +import io.ktor.server.request.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import org.coralprotocol.coralserver.agent.graph.AgentGraph +import org.coralprotocol.coralserver.agent.graph.GraphAgent +import org.coralprotocol.coralserver.agent.graph.GraphAgentProvider +import org.coralprotocol.coralserver.config.ConfigCollection +import org.coralprotocol.coralserver.agent.registry.defaultAsValue +import org.coralprotocol.coralserver.agent.runtime.RuntimeId +import org.coralprotocol.coralserver.server.RouteException +import org.coralprotocol.coralserver.session.* + +private val logger = KotlinLogging.logger {} + +@Suppress("UNCHECKED_CAST") +fun Map.filterNotNullValues(): Map = + filterValues { it != null } as Map + +@Resource("/api/v1/sessions") +class Sessions + +/** + * Configures session-related routes. + */ +fun Routing.sessionApiRoutes(appConfig: ConfigCollection, sessionManager: SessionManager, devMode: Boolean) { + post({ + summary = "Create session" + description = "Creates a new session" + operationId = "createSession" + request { + body { + description = "Session creation request" + } + } + response { + HttpStatusCode.OK to { + description = "Success" + body { + description = "Session details" + } + } + HttpStatusCode.Forbidden to { + description = "Invalid application ID or privacy key" + } + HttpStatusCode.BadRequest to { + description = "The agent graph is invalid and could not be processed" + } + } + }) { + val request = call.receive() + + if (!devMode && !appConfig.isValidApplication(request.applicationId, request.privacyKey)) { + throw RouteException(HttpStatusCode.Forbidden, "Invalid App ID or privacy key") + } + + val agentGraph = request.agentGraph?.let { it -> + val requestedAgents = it.agents + val registry = appConfig.registry + + val missingAgentLinks = it.links.map { set -> + set.filter { agent -> !it.agents.containsKey(agent) } + }.flatten() + + if (missingAgentLinks.isNotEmpty()) { + throw RouteException(HttpStatusCode.BadRequest, + "Links contained agents that are not in the request: ${missingAgentLinks.joinToString()}") + } + + val missingAgents = it.agents.filter { (name, _) -> !registry.importedAgents.containsKey(name) } + if (missingAgents.isNotEmpty()) { + throw RouteException(HttpStatusCode.BadRequest, + "Requested agents not found: ${missingAgentLinks.joinToString()}") + } + + AgentGraph( + tools = it.tools, + links = it.links, + agents = requestedAgents.mapValues { (name, request) -> + // The badAgents check above will ensure this is never null + // + // It'd be more idiomatic to throw the exception here, but the error is nicer when it + // contains all the missing agents in the graph, not just the first one + val agent = registry.importedAgents[name]!! + + val missingRequiredOptions = agent.options.filter { option -> + option.value.required && !request.options.containsKey(option.key) + } + if (missingRequiredOptions.isNotEmpty()) { + throw RouteException( + HttpStatusCode.BadRequest, + "Agent '${name}' is missing required options: ${missingRequiredOptions.keys.joinToString()}" + ) + } + + val missingAgentOptions = request.options.filter { + !agent.options.containsKey(it.key) + } + if (missingAgentOptions.isNotEmpty()) { + throw RouteException( + HttpStatusCode.BadRequest, + "Agent '${name}' contains non-existent options: ${missingRequiredOptions.keys.joinToString()}" + ) + } + + val defaultOptions = + agent.options.mapValues { option -> option.value.defaultAsValue() } + .filterNotNullValues() + + GraphAgent( + name, + blocking = request.blocking ?: true, + extraTools = request.tools, + systemPrompt = request.systemPrompt, + options = defaultOptions + request.options, + provider = request.provider + ) + } + ) + } + + // TODO(alan): actually limit agent communicating using AgentGraph.links + // Create a new session + val session = when (request.sessionId != null && devMode) { + true -> { + try { + sessionManager.createSessionWithId( + request.sessionId, + request.applicationId, + request.privacyKey, + agentGraph + ) + } + catch (e: Exception) { + // TODO: An exception should be made for + throw e + } + } + + false -> { + sessionManager.createSession(request.applicationId, request.privacyKey, agentGraph) + } + } + + // Return the session details + call.respond( + CreateSessionResponse( + sessionId = session.id, + applicationId = session.applicationId, + privacyKey = session.privacyKey + ) + ) + + logger.info { "Created new session ${session.id} for application ${session.applicationId}" } + } + + // TODO: this should probably be protected (only for debug maybe) + get({ + summary = "Get sessions" + description = "Fetches all active session IDs" + operationId = "getSessions" + response { + HttpStatusCode.OK to { + description = "Success" + body> { + description = "List of session IDs" + } + } + } + }) { + val sessions = sessionManager.getAllSessions() + call.respond(HttpStatusCode.OK, sessions.map { it.id }) + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/Telemetry.kt b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/Telemetry.kt new file mode 100644 index 0000000000000000000000000000000000000000..d0c481d0fa3a8b1e69a716b3df1f878f58780843 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/routes/api/v1/Telemetry.kt @@ -0,0 +1,107 @@ +package org.coralprotocol.coralserver.routes.api.v1 + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.github.smiley4.ktoropenapi.resources.get +import io.github.smiley4.ktoropenapi.resources.post +import io.ktor.http.* +import io.ktor.resources.* +import io.ktor.server.request.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import org.coralprotocol.coralserver.models.Message +import org.coralprotocol.coralserver.models.Telemetry +import org.coralprotocol.coralserver.server.RouteException +import org.coralprotocol.coralserver.session.SessionManager +import org.coralprotocol.coralserver.models.TelemetryPost as TelemetryPostModel + +private val logger = KotlinLogging.logger {} + +@Resource("/api/v1/telemetry/{sessionId}/{threadId}/{messageId}") +class TelemetryGet(val sessionId: String, val threadId: String, val messageId: String) { + fun intoMessage(sessionManager: SessionManager): Message { + val session = sessionManager.getSession(sessionId) ?: throw RouteException( + HttpStatusCode.NotFound, + "Session not found" + ) + + val thread = session.getThread(threadId) ?: throw RouteException( + HttpStatusCode.NotFound, + "Thread not found" + ) + + // TODO: messages should be a map (@Caelum told me to do this (the bad code not the comment)) + return thread.messages.find { it.id == messageId } ?: throw RouteException( + HttpStatusCode.NotFound, + "Message not found" + ) + } +} + +@Resource("/api/v1/telemetry/{sessionId}") +class TelemetryPost(val sessionId: String) + +fun Routing.telemetryApiRoutes(sessionManager: SessionManager) { + get({ + summary = "Get telemetry" + description = "Fetches telemetry information for a given message" + operationId = "getTelemetry" + request { + pathParameter("sessionId") { + description = "The session ID" + } + pathParameter("threadId") { + description = "The thread ID" + } + pathParameter("messageId") { + description = "The message ID" + } + } + response { + HttpStatusCode.OK to { + description = "Success" + body { + description = "Telemetry data" + } + } + HttpStatusCode.NotFound to { + description = "Telemetry data not found for specified message" + } + } + }) { telemetry -> + call.respond(telemetry.intoMessage(sessionManager).telemetry ?: throw RouteException( + HttpStatusCode.NotFound, + "Telemetry not found" + )) + } + + post({ + summary = "Add telemetry" + description = "Attaches telemetry information a list of messages" + operationId = "addTelemetry" + request { + pathParameter("sessionId") { + description = "The session ID" + } + body { + description = "Telemetry data" + } + } + response { + HttpStatusCode.OK to { + description = "Success" + } + } + }) { post -> + val model = call.receive() + for (target in model.targets) { + val message = TelemetryGet(post.sessionId, target.threadId, target.messageId) + .intoMessage(sessionManager) + + // maybe error if there is telemetry on this message already? + message.telemetry = model.data + logger.info { "Adding telemetry to ${target.threadId}/${message.id} in session \"${post.sessionId}\"" } + } + + call.respond(status = HttpStatusCode.OK, "") + } +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/routes/sse/v1/ConnectionRoutes.kt b/src/main/kotlin/org/coralprotocol/coralserver/routes/sse/v1/ConnectionRoutes.kt new file mode 100644 index 0000000000000000000000000000000000000000..33467d8f340ffde2cc2e4d4ffbce3987cbe8620c --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/routes/sse/v1/ConnectionRoutes.kt @@ -0,0 +1,189 @@ +package org.coralprotocol.coralserver.routes.sse.v1 + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.http.* +import io.ktor.server.request.host +import io.ktor.server.request.port +import io.ktor.server.request.uri +import io.ktor.server.response.* +import io.ktor.server.routing.* +import io.ktor.server.sse.* +import io.ktor.util.collections.* +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.SseServerTransport +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp +import org.coralprotocol.coralserver.session.SessionManager + +private val logger = KotlinLogging.logger {} + +/** + * Configures SSE-related routes that handle initial client connections. + * These endpoints establish bidirectional communication channels and must be hit + * before any message processing can begin. + */ +fun Routing.connectionSseRoutes(servers: ConcurrentMap, sessionManager: SessionManager) { + suspend fun ServerSSESession.handleSseConnection(isDevMode: Boolean = false) { + handleSseConnection( + "coral://" + call.request.host() + ":" + call.request.port() + call.request.uri, + call.parameters, + this, + servers, + sessionManager = sessionManager, + isDevMode + ) + } + + sse("/sse/v1/{applicationId}/{privacyKey}/{coralSessionId}") { + handleSseConnection() + } + + sse("/sse/v1/devmode/{applicationId}/{privacyKey}/{coralSessionId}") { + handleSseConnection(true) + } + + /* + The following routes are added as aliases for any piece of existing software that requires that the URL ends + with /sse + */ + + sse("/sse/v1/{applicationId}/{privacyKey}/{coralSessionId}/sse") { + handleSseConnection() + } + + sse("/sse/v1/devmode/{applicationId}/{privacyKey}/{coralSessionId}/sse") { + handleSseConnection(true) + } +} + +/** + * Centralizes SSE connection handling for both production and development modes. + * Dev mode skips validation and allows on-demand session creation for testing, + * while production enforces security checks and requires pre-created sessions. + */ +private suspend fun handleSseConnection( + uri: String, + parameters: Parameters, + sseProducer: ServerSSESession, + servers: ConcurrentMap, + sessionManager: SessionManager, + isDevMode: Boolean +): Boolean { + val applicationId = parameters["applicationId"] + val privacyKey = parameters["privacyKey"] + val sessionId = parameters["coralSessionId"] + val agentId = parameters["agentId"] + val agentDescription: String = parameters["agentDescription"] ?: agentId ?: "no description" + val maxWaitForMentionsTimeout = parameters["maxWaitForMentionsTimeout"]?.toLongOrNull() ?: 60000 + + if (agentId == null) { + sseProducer.call.respond(HttpStatusCode.BadRequest, "Missing agentId parameter") + return false + } + + if (applicationId == null || privacyKey == null || sessionId == null) { + sseProducer.call.respond(HttpStatusCode.BadRequest, "Missing required parameters") + return false + } + + val session = if (isDevMode) { + val waitForAgents = sseProducer.call.request.queryParameters["waitForAgents"]?.toIntOrNull() ?: 0 + val createdSession = sessionManager.getOrCreateSession(sessionId, applicationId, privacyKey, null) + + if (waitForAgents > 0) { + createdSession.devRequiredAgentStartCount = waitForAgents + logger.info { "DevMode: Setting waitForAgents=$waitForAgents for session $sessionId" } + } + + createdSession + } else { + val existingSession = sessionManager.getSession(sessionId) + if (existingSession == null) { + sseProducer.call.respond(HttpStatusCode.NotFound, "Session not found") + return false + } + + if (existingSession.applicationId != applicationId || existingSession.privacyKey != privacyKey) { + sseProducer.call.respond(HttpStatusCode.Forbidden, "Invalid application ID or privacy key for this session") + return false + } + + existingSession + } + val currentCount = session.getRegisteredAgentsCount() + + // TODO: better route err handling + val agent = try { + val agent = when (isDevMode) { + true -> { + val agent = session.registerAgent(agentId, uri, agentDescription, force = true) + session.connectAgent(agentId) + agent!! // never null when force = true + } + false -> { + val agent = session.connectAgent(agentId) + if(agent != null) { + agent.description = agentDescription + agent.mcpUrl = uri + } + agent + } + } + if (agent == null) { + logger.info { "Agent ID $agentId does not exist!" } + sseProducer.call.respond(HttpStatusCode.NotFound, "Agent ID does not exist") + return false + } + agent + } catch (e: Exception) { + logger.info { "Agent ID $agentId already connected!" } + sseProducer.call.respond(HttpStatusCode.BadRequest, "Agent ID already connected") + return false + } + + logger.info { "DevMode: Current agent count for session ${session.id} (object id: ${session}) (from sessionmanager: ${sessionManager}): $currentCount, waiting for: ${session.devRequiredAgentStartCount}" } + val newCount = session.getRegisteredAgentsCount() + logger.info { "DevMode: New agent count for session ${session.id} (object id: ${session})after registering: $newCount" } + + val routeSuffix = if (isDevMode) "devmode/" else "" + val endpoint = "/api/v1/message/$routeSuffix$applicationId/$privacyKey/$sessionId" + val transport = SseServerTransport(endpoint, sseProducer) + + val individualServer = + CoralAgentIndividualMcp(uri, transport, session, agentId, maxWaitForMentionsTimeout, extraTools = agent.extraTools) + session.coralAgentConnections.add(individualServer) + + val transportSessionId = transport.sessionId + servers[transportSessionId] = individualServer + + val success = session.waitForGroup(agentId, 60000) + if (success) { + logger.info { "Agent $agentId successfully waited for group" } + } else { + logger.warn { "Agent $agentId failed waiting for group, proceeding anyway.." } + } + + if (isDevMode) { + logger.info { "DevMode: Connected to session $sessionId with application $applicationId (waitForAgents=${session.devRequiredAgentStartCount})" } + + if (session.devRequiredAgentStartCount > 0) { + if (newCount < session.devRequiredAgentStartCount) { + + val success = session.waitForAgentCount(session.devRequiredAgentStartCount, 60000) + if (success) { + logger.info { "DevMode: Successfully waited for ${session.devRequiredAgentStartCount} agents to connect" } + } else { + logger.warn { "DevMode: Timeout waiting for ${session.devRequiredAgentStartCount} agents to connect, proceeding anyway" } + } + } else { + logger.info { "DevMode: Required agent count already reached" } + } + } + } + + individualServer.connect(transport) + individualServer.onClose { + logger.info { "Agent $agentId disconnected via server." } + session.disconnectAgent(agentId); + } + return true +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/routes/ws/v1/DebugRoutes.kt b/src/main/kotlin/org/coralprotocol/coralserver/routes/ws/v1/DebugRoutes.kt new file mode 100644 index 0000000000000000000000000000000000000000..3f7128b52c23499cf2643f13d73fb0ae22857d37 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/routes/ws/v1/DebugRoutes.kt @@ -0,0 +1,57 @@ +package org.coralprotocol.coralserver.routes.ws.v1 + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.http.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import io.ktor.server.websocket.* +import org.coralprotocol.coralserver.models.SocketEvent +import org.coralprotocol.coralserver.models.resolve +import org.coralprotocol.coralserver.session.SessionManager + +private val logger = KotlinLogging.logger {} + +fun Routing.debugWsRoutes(sessionManager: SessionManager) { + webSocket("/ws/v1/debug/{applicationId}/{privacyKey}/{coralSessionId}/") { + val applicationId = call.parameters["applicationId"] + val privacyKey = call.parameters["privacyKey"] + // TODO (alan): proper appId/privacyKey based lookups when session manager is updated + val sessionId = call.parameters["coralSessionId"] ?: throw IllegalArgumentException("Missing sessionId") + + val timeout = call.parameters["timeout"]?.toLongOrNull() ?: 1000 + + val session = sessionManager.waitForSession(sessionId, timeout); + if (session == null) { + call.respond(HttpStatusCode.NotFound, "Session not found") + return@webSocket + } + + val debugId = session.registerDebugAgent() + sendSerialized(SocketEvent.DebugAgentRegistered(id = debugId.id)) + + sendSerialized(SocketEvent.ThreadList(session.getAllThreads().map { it.resolve() })) + sendSerialized(SocketEvent.AgentList(session.getAllAgents(false))) + + session.events.collect { evt -> + logger.debug { "Received evt: $evt" } + sendSerialized(SocketEvent.Session(evt)) + } + } + + webSocket("/ws/v1/debug/{applicationId}/{privacyKey}/{coralSessionId}/{agentId}/logs") { + val applicationId = call.parameters["applicationId"] ?: throw IllegalArgumentException("Missing applicationId") + val privacyKey = call.parameters["privacyKey"] ?: throw IllegalArgumentException("Missing privacyKey") + val sessionId = call.parameters["coralSessionId"] ?: throw IllegalArgumentException("Missing sessionId") + val agentId = call.parameters["agentId"] ?: throw IllegalArgumentException("Missing agentId") + + val bus = sessionManager.orchestrator.getBus(sessionId, agentId) ?: run { + call.respond(HttpStatusCode.NotFound, "Agent not found") + return@webSocket; + }; + + bus.events.collect { evt -> + logger.debug { "Received evt: $evt" } + sendSerialized(evt) + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/server/CoralAgentIndividualMcp.kt b/src/main/kotlin/org/coralprotocol/coralserver/server/CoralAgentIndividualMcp.kt new file mode 100644 index 0000000000000000000000000000000000000000..5acb540aaa5c2e635b565e513b30b1d9e452966e --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/server/CoralAgentIndividualMcp.kt @@ -0,0 +1,63 @@ +package org.coralprotocol.coralserver.server + +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.SseServerTransport +import org.coralprotocol.coralserver.mcpresources.addMessageResource +import org.coralprotocol.coralserver.session.CoralAgentGraphSession +import org.coralprotocol.coralserver.mcptools.addThreadTools +import org.coralprotocol.coralserver.session.CustomTool +import org.coralprotocol.coralserver.session.addExtraTool + +/** + * Represents a persistent connection to a Coral agent. + * Each agent instance has a unique MCP server instance assigned to it. + * + * CoralSession + * + * This [CoralAgentIndividualMcp] should persist even if the agent reconnects via a different transport. + */ +class CoralAgentIndividualMcp( + val connectedUri: String, + /** + * The latest transport used by the agent to connect to the server. It might change if the agent reconnects. + */ + var latestTransport: SseServerTransport, + /** + * The session this agent is part of. + */ + val coralAgentGraphSession: CoralAgentGraphSession, + /** + * The ID of the agent associated with this connection. + */ + val connectedAgentId: String, + val maxWaitForMentionsTimeoutMs: Long = 2000, + val extraTools: Set = setOf(), +) : Server( + Implementation( + name = "Coral Server", + version = "0.1.0" + ), + ServerOptions( + capabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts(listChanged = true), + resources = ServerCapabilities.Resources(subscribe = true, listChanged = true), + tools = ServerCapabilities.Tools(listChanged = true), + ) + ), +) { + init { + addThreadTools() + addMessageResource() + extraTools.forEach { + addExtraTool(coralAgentGraphSession.id, connectedAgentId, it) + } + } + + suspend fun closeTransport() { + latestTransport.close() + } +} + diff --git a/src/main/kotlin/org/coralprotocol/coralserver/server/CoralServer.kt b/src/main/kotlin/org/coralprotocol/coralserver/server/CoralServer.kt new file mode 100644 index 0000000000000000000000000000000000000000..f3b75aaae4cecc5f2a4a273d6f9ff27f54690fb5 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/server/CoralServer.kt @@ -0,0 +1,245 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.server + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.github.smiley4.ktoropenapi.OpenApi +import io.github.smiley4.ktoropenapi.config.OutputFormat +import io.github.smiley4.ktoropenapi.config.SchemaGenerator +import io.github.smiley4.ktoropenapi.openApi +import io.github.smiley4.ktoropenapi.route +import io.github.smiley4.schemakenerator.core.CoreSteps.addMissingSupertypeSubtypeRelations +import io.github.smiley4.schemakenerator.core.CoreSteps.handleNameAnnotation +import io.github.smiley4.schemakenerator.serialization.SerializationSteps.addJsonClassDiscriminatorProperty +import io.github.smiley4.schemakenerator.serialization.SerializationSteps.analyzeTypeUsingKotlinxSerialization +import io.github.smiley4.schemakenerator.serialization.SerializationSteps.renameMembers +import io.github.smiley4.schemakenerator.swagger.SwaggerSteps.compileReferencingRoot +import io.github.smiley4.schemakenerator.swagger.SwaggerSteps.customizeTypes +import io.github.smiley4.schemakenerator.swagger.SwaggerSteps.generateSwaggerSchema +import io.github.smiley4.schemakenerator.swagger.SwaggerSteps.handleCoreAnnotations +import io.github.smiley4.schemakenerator.swagger.SwaggerSteps.handleSchemaAnnotations +import io.github.smiley4.schemakenerator.swagger.SwaggerSteps.mergePropertyAttributesIntoType +import io.github.smiley4.schemakenerator.swagger.SwaggerSteps.withTitle +import io.github.smiley4.schemakenerator.swagger.TitleBuilder +import io.github.smiley4.schemakenerator.swagger.data.TitleType +import io.ktor.http.* +import io.ktor.resources.Resource +import io.ktor.serialization.kotlinx.* +import io.ktor.serialization.kotlinx.json.* +import io.ktor.server.application.* +import io.ktor.server.cio.* +import io.ktor.server.engine.* +import io.ktor.server.plugins.contentnegotiation.* +import io.ktor.server.plugins.cors.routing.* +import io.ktor.server.plugins.statuspages.* +import io.ktor.server.resources.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import io.ktor.server.sse.* +import io.ktor.server.websocket.* +import io.ktor.util.collections.* +import io.modelcontextprotocol.kotlin.sdk.server.Server +import kotlinx.coroutines.Job +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonNamingStrategy +import org.coralprotocol.coralserver.config.ConfigCollection +import org.coralprotocol.coralserver.models.SocketEvent +import org.coralprotocol.coralserver.routes.api.v1.debugApiRoutes +import org.coralprotocol.coralserver.routes.api.v1.agentApiRoutes +import org.coralprotocol.coralserver.routes.api.v1.documentationApiRoutes +import org.coralprotocol.coralserver.routes.api.v1.messageApiRoutes +import org.coralprotocol.coralserver.routes.api.v1.sessionApiRoutes +import org.coralprotocol.coralserver.routes.api.v1.telemetryApiRoutes +import org.coralprotocol.coralserver.routes.sse.v1.connectionSseRoutes +import org.coralprotocol.coralserver.routes.ws.v1.debugWsRoutes +import org.coralprotocol.coralserver.session.SessionManager +import kotlin.time.Duration.Companion.seconds + +private val logger = KotlinLogging.logger {} + +// It is important that this naming strategy is used both for deserialization and for the OpenAPI generation +// so that the spec generates the correct property names for (de)serialization +private val NAMING_STRATEGY = JsonNamingStrategy.SnakeCase + +private val json = Json { + encodeDefaults = true + prettyPrint = true + explicitNulls = false + namingStrategy = NAMING_STRATEGY +} + +/** + * CoralServer class that encapsulates the SSE MCP server functionality. + * + * @param host The host to run the server on + * @param port The port to run the server on + * @param devmode Whether the server is running in development mode + */ +class CoralServer( + val host: String = "0.0.0.0", + val port: UShort = 5555u, + val appConfig: ConfigCollection, + val devmode: Boolean = false, + val sessionManager: SessionManager = SessionManager(port = port), +) { + + /* + /{destinationId}/{protocol}/{version}/depends.... + + /api/v2/forward/{destinationId}/ get/head/post/etc -> ...agent traffic: telemetry + sse + anything else in the future + + local: + generic api: /api/v1/message/{applicationId}/{privacyKey}/{coralSessionId} + sse: * /sse/v1/1/{applicationId}/{privacyKey}/{coralSessionId} + coral studio: /ws/v1/1/debug/{applicationId}/{privacyKey}/{coralSessionId} + + */ + private val mcpServersByTransportId = ConcurrentMap() + private var server: EmbeddedServer = + embeddedServer(CIO, host = host, port = port.toInt(), watchPaths = listOf("classes")) { + install(OpenApi) { + info { + title = "Coral Server API" + } + spec("v1") { + info { + version = "1.0" + } + tags { + tagGenerator = { url -> listOf(url.getOrNull(2)) } + } + schemas { + generator = SchemaGenerator.kotlinx { } + // Generated types from routes + generator = { type -> + type + .analyzeTypeUsingKotlinxSerialization { + + } + .addMissingSupertypeSubtypeRelations() + .addJsonClassDiscriminatorProperty() + .handleNameAnnotation().renameMembers(NAMING_STRATEGY) + .generateSwaggerSchema({ + strictDiscriminatorProperty = true + }) + .handleCoreAnnotations() + .handleSchemaAnnotations() + .customizeTypes { _, schema -> + // Mapping is broken, and one of the code generation libraries I am using checks the + // references here + schema.discriminator?.mapping = null; + } + .withTitle(TitleType.SIMPLE) + .compileReferencingRoot( + explicitNullTypes = false, + inlineDiscriminatedTypes = true, + builder = TitleBuilder.BUILDER_OPENAPI_SIMPLE + ) + } + + // Other types, used by SSE or WebSocket routes + schema("SocketEvent") + } + } + specAssigner = { url: String, tags: List -> + // when another spec version is added, determine the version based on the url here or use + // specVersion on the new routes + "v1" + } + pathFilter = { method, parts -> + parts.getOrNull(0) == "api" + } + outputFormat = OutputFormat.JSON + } + install(Resources) + install(SSE) + install(ContentNegotiation) { + json(json, contentType = ContentType.Application.Json) + } + install(WebSockets) { + contentConverter = KotlinxWebsocketSerializationConverter(Json) + pingPeriod = 5.seconds + timeout = 15.seconds + maxFrameSize = Long.MAX_VALUE + masking = false + } + // TODO: probably restrict this down the line + install(CORS) { + allowMethod(HttpMethod.Options) + allowMethod(HttpMethod.Post) + allowMethod(HttpMethod.Get) + allowHeader(HttpHeaders.AccessControlAllowOrigin) + allowHeader(HttpHeaders.ContentType) + anyHost() + } + install(StatusPages) { + exception { call, cause -> + // Other exceptions should still be serialized, wrap non RouteException type exceptions in a + // RouteException, giving a 500-status code + var wrapped = cause + if (cause !is RouteException) { + wrapped = RouteException(HttpStatusCode.InternalServerError, cause.message) + } + + call.respondText(text = json.encodeToString(wrapped), status = wrapped.status) + } + } + routing { + // api + debugApiRoutes(sessionManager) + sessionApiRoutes(appConfig, sessionManager, devmode) + messageApiRoutes(mcpServersByTransportId, sessionManager) + telemetryApiRoutes(sessionManager) + documentationApiRoutes() + agentApiRoutes(appConfig, sessionManager) + + // sse + connectionSseRoutes(mcpServersByTransportId, sessionManager) + + // websocket + debugWsRoutes(sessionManager) + + // source of truth for OpenAPI docs/codegen + route("api_v1.json") { + openApi("v1") + } + } + } + val monitor get() = server.monitor + private var serverJob: Job? = null + + /** + * Starts the server. + */ + fun start(wait: Boolean = false) { + logger.info { "Starting sse server on port $port with ${appConfig.appConfig.applications.size} configured applications" } + System.setProperty("org.slf4j.simpleLogger.defaultLogLevel", "trace"); + + if (devmode) { + logger.info { + "In development, agents can connect to " + + "http://localhost:$port/sse/v1/exampleApplicationId/examplePrivacyKey/exampleSessionId/sse?agentId=exampleAgent" + } + logger.info { + "Connect the inspector to " + + "http://localhost:$port/sse/v1/devmode/exampleApplicationId/examplePrivacyKey/exampleSessionId/sse?agentId=inspector" + } + } + server.monitor.subscribe(ApplicationStarted) { + logger.info { "Server started on $host:$port" } + } + server.start(wait) + } + + /** + * Stops the server. + */ + fun stop() { + logger.info { "Stopping server..." } + serverJob?.cancel() + server.stop(1000, 2000) + serverJob = null + logger.info { "Server stopped" } + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/server/RouteException.kt b/src/main/kotlin/org/coralprotocol/coralserver/server/RouteException.kt new file mode 100644 index 0000000000000000000000000000000000000000..c0e4a5c9938731e18eed6a054d21ff10bf4c0cdf --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/server/RouteException.kt @@ -0,0 +1,19 @@ +package org.coralprotocol.coralserver.server + +import io.ktor.http.* +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.Transient + +@Serializable +data class RouteException( + @Transient + val status: HttpStatusCode = throw IllegalArgumentException("class cannot be deserialized"), + + @Suppress("unused") + @SerialName("message") + val routeExceptionMessage: String?) : Exception(routeExceptionMessage) +{ + @Suppress("unused") + val stackTrace = super.stackTrace.map { it.toString() } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/server/StdioServer.kt b/src/main/kotlin/org/coralprotocol/coralserver/server/StdioServer.kt new file mode 100644 index 0000000000000000000000000000000000000000..2bfe0a2b3d311b51d8074c83ff8b9d229b62f6ed --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/server/StdioServer.kt @@ -0,0 +1,33 @@ +//TODO: Determine whether we should allow connection via stdio +//package org.coralprotocol.coralserver.server +// +//import io.github.oshai.kotlinlogging.KotlinLogging +//import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport +//import kotlinx.coroutines.Job +//import kotlinx.coroutines.runBlocking +//import kotlinx.io.asSink +//import kotlinx.io.asSource +//import kotlinx.io.buffered +// +//private val logger = KotlinLogging.logger {} +// +///** +// * Runs an MCP server using standard input/output. +// * The server will handle listing prompts, tools, and resources automatically. +// */ +//fun runMcpServerUsingStdio() { +// // Note: The server will handle listing prompts, tools, and resources automatically. +// // The handleListResourceTemplates will return empty as defined in the Server code. +// val server = createCoralMcpServer() +// val transport = StdioServerTransport( +// inputStream = System.`in`.asSource().buffered(), +// outputStream = System.out.asSink().buffered() +// ) +// +// runBlocking { +// server.connect(transport) +// val done = Job() +// done.join() +// logger.info { "Server closed" } +// } +//} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/session/CoralAgentGraphSession.kt b/src/main/kotlin/org/coralprotocol/coralserver/session/CoralAgentGraphSession.kt new file mode 100644 index 0000000000000000000000000000000000000000..aa10590c4f8b0fe7cde48944d61f41ee4e5520a4 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/session/CoralAgentGraphSession.kt @@ -0,0 +1,317 @@ +package org.coralprotocol.coralserver.session + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.util.collections.* +import kotlinx.coroutines.CompletableDeferred +import org.coralprotocol.coralserver.EventBus +import org.coralprotocol.coralserver.agent.graph.AgentGraph +import org.coralprotocol.coralserver.models.* +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +private val logger = KotlinLogging.logger {} + +/** + * Session class to hold stateful information for a specific application and privacy key. + * [devRequiredAgentStartCount] is the number of agents that need to register before the session can proceed. This is for devmode only. + */ +class CoralAgentGraphSession( + val id: String, + val applicationId: String, + val privacyKey: String, + val agentGraph: AgentGraph?, + val coralAgentConnections: MutableList = mutableListOf(), + val groups: List> = listOf(), + var devRequiredAgentStartCount: Int = 0, +) { + private var agents = ConcurrentHashMap() + private val debugAgents = ConcurrentSet() + + private val threads = ConcurrentHashMap() + + private val agentNotifications = ConcurrentHashMap>>() + + private val lastReadMessageIndex = ConcurrentHashMap, Int>() + + private val agentGroupScheduler = GroupScheduler(groups) + private val countBasedScheduler = CountBasedScheduler() + + private val eventBus = EventBus() + val events get() = eventBus.events + + init { + agentGraph?.run { + for (id in agents.keys) { + registerAgent(id.toString()) + setAgentState(agentId = id.toString(), state = AgentState.Connecting) + } + } + } + + + fun getAllThreadsAgentParticipatesIn(agentId: String): List { + return threads.values.filter { it.participants.contains(agentId) } + } + + fun clearAll() { + agents.clear() + threads.clear() + agentNotifications.clear() + lastReadMessageIndex.clear() + countBasedScheduler.clear() + agentGroupScheduler.clear() + } + + fun connectAgent(agentId: String): Agent? { + val agent = agents[agentId] ?: return null +// if (agent.state.isConnected()) throw AssertionError("Agent $agentId is already connected") + if (agent.state.isConnected()) logger.warn { "Agent $agentId is already connected" } + agent.state = AgentState.Busy; + agentGroupScheduler.markAgentReady(agentId) + countBasedScheduler.markAgentReady(agent.id) + eventBus.emit(SessionEvent.AgentStateUpdated(agent.id, agent.state)) + return agent + } + + fun setAgentState(agentId: String, state: AgentState): AgentState? { + val agent = agents[agentId] ?: return null + val oldState = agent.state + if (oldState == AgentState.Connecting && state.isConnected()) { + agentGroupScheduler.markAgentReady(agentId) + countBasedScheduler.markAgentReady(agent.id) + } + agent.state = state + eventBus.emit(SessionEvent.AgentStateUpdated(agent.id, agent.state)) + return oldState + + } + + fun disconnectAgent(agentId: String) { + val agent = agents[agentId] ?: return + agent.state = AgentState.Disconnected; + eventBus.emit(SessionEvent.AgentStateUpdated(agent.id, agent.state)) + } + + fun registerAgent( + agentId: String, + agentUri: String? = null, + agentDescription: String? = null, + force: Boolean = false + ): Agent? { + if (agents.containsKey(agentId)) { + logger.warn { "$agentId has already been registered" } + if (!force) { + return null; + } + } + val agent = Agent( + id = agentId, + description = agentDescription ?: "", + extraTools = agentGraph?.let { + it.agents[agentId]?.extraTools?.mapNotNull { tool -> it.tools[tool] }?.toSet() + } ?: setOf(), + mcpUrl = agentUri + ) + agents[agent.id] = agent + + return agent + } + + fun getRegisteredAgentsCount(): Int = countBasedScheduler.getRegisteredAgentsCount() + + suspend fun waitForGroup(agentId: String, timeoutMs: Long): Boolean = + agentGroupScheduler.waitForGroup(agentId, timeoutMs) + + suspend fun waitForAgentCount(targetCount: Int, timeoutMs: Long): Boolean = + countBasedScheduler.waitForAgentCount(targetCount, timeoutMs) + + fun getAgent(agentId: String): Agent? = agents[agentId] + + fun getAllAgents(includeDebug: Boolean = false): List = when (includeDebug) { + true -> agents.values.toList() + false -> agents.values.filter { !debugAgents.contains(it.id) } + } + + fun getAllThreads(): List = threads.values.toList() + + fun registerDebugAgent(): Agent { + val id = UUID.randomUUID().toString() + if (agents[id] !== null) throw AssertionError("Debug agent id collision") + val agent = Agent(id = id, description = "", mcpUrl = "n/a") + agents[id] = agent + debugAgents.add(id) + return agent + } + + fun createThread(name: String, creatorId: String, participantIds: List): Thread { + if (creatorId != "debug" && !agents.containsKey(creatorId)) { + throw IllegalArgumentException("Creator agent $creatorId not found") + } + + val validParticipants = participantIds.filter { agents.containsKey(it) }.toMutableList() + + if (!validParticipants.contains(creatorId)) { + validParticipants.add(creatorId) + } + + val thread = Thread( + name = name, + creatorId = creatorId, + participants = validParticipants + ) + threads[thread.id] = thread + + eventBus.emit( + SessionEvent.ThreadCreated( + id = thread.id, + name = name, + creatorId = creatorId, + participants = validParticipants, + summary = null + ) + ) + return thread + } + + fun getThread(threadId: String): Thread? = threads[threadId] + + fun getThreadsForAgent(agentId: String): List { + return threads.values.filter { it.participants.contains(agentId) } + } + + fun addParticipantToThread(threadId: String, participantId: String): Boolean { + val thread = threads[threadId] ?: return false + val agent = agents[participantId] ?: return false + + if (thread.isClosed) return false + + if (!thread.participants.contains(participantId)) { + thread.participants.add(participantId) + lastReadMessageIndex[Pair(participantId, threadId)] = thread.messages.size + } + return true + } + + fun removeParticipantFromThread(threadId: String, participantId: String): Boolean { + val thread = threads[threadId] ?: return false + + if (thread.isClosed) return false + + return thread.participants.remove(participantId) + } + + fun closeThread(threadId: String, summary: String): Boolean { + val thread = threads[threadId] ?: return false + + thread.isClosed = true + thread.summary = summary + + return true + } + + fun getColorForSenderId(senderId: String): String { + val colors = listOf( + "#FF5733", "#33FF57", "#3357FF", "#F3FF33", "#FF33F3", + "#33FFF3", "#FF8033", "#8033FF", "#33FF80", "#FF3380" + ) + val hash = senderId.hashCode() + val index = Math.abs(hash) % colors.size + return colors[index] + } + + fun sendMessage( + threadId: String, + senderId: String, + content: String, + mentions: List = emptyList() + ): Message { + val thread = getThread(threadId) ?: throw IllegalArgumentException("Thread with id $threadId not found") + val sender = getAgent(senderId) ?: throw IllegalArgumentException("Agent with id $senderId not found") + + val message = Message.create(thread, sender, content, mentions) + thread.messages.add(message) + eventBus.emit(SessionEvent.MessageSent(threadId, message.resolve())) + notifyMentionedAgents(message) + return message + } + + private fun notifyMentionedAgents(message: Message) { + if (message.sender.id == "system") { + val thread = threads[message.thread.id] ?: return + thread.participants.forEach { participantId -> + val deferred = agentNotifications[participantId] + if (deferred != null && !deferred.isCompleted) { + deferred.complete(listOf(message)) + } + } + return + } + + message.mentions.forEach { mentionId -> + val deferred = agentNotifications[mentionId] + if (deferred != null && !deferred.isCompleted) { + deferred.complete(listOf(message)) + } + } + } + + suspend fun waitForMentions(agentId: String, timeoutMs: Long): List { + if (timeoutMs <= 0) { + throw IllegalArgumentException("Timeout must be greater than 0") + } + + val agent = agents[agentId] ?: return emptyList() + + val unreadMessages = getUnreadMessagesForAgent(agentId) + if (unreadMessages.isNotEmpty()) { + updateLastReadIndices(agentId, unreadMessages) + return unreadMessages + } + + val deferred = CompletableDeferred>() + agentNotifications[agentId] = deferred + + val result = kotlinx.coroutines.withTimeoutOrNull(timeoutMs) { + deferred.await() + } ?: emptyList() + + agentNotifications.remove(agentId) + + updateLastReadIndices(agentId, result) + + return result + } + + fun getUnreadMessagesForAgent(agentId: String): List { + val agent = agents[agentId] ?: return emptyList() + + val result = mutableListOf() + + val agentThreads = getThreadsForAgent(agentId) + + for (thread in agentThreads) { + val lastReadIndex = lastReadMessageIndex[Pair(agentId, thread.id)] ?: 0 + + val unreadMessages = thread.messages.subList(lastReadIndex, thread.messages.size) + + result.addAll(unreadMessages.filter { + it.mentions.contains(agentId) || it.sender.id == "system" + }) + } + + return result + } + + fun updateLastReadIndices(agentId: String, messages: List) { + val messagesByThread = messages.groupBy { it.thread } + + for ((thread, threadMessages) in messagesByThread) { + val messageIndices = threadMessages.map { thread.messages.indexOf(it) } + if (messageIndices.isNotEmpty()) { + val maxIndex = messageIndices.maxOrNull() ?: continue + lastReadMessageIndex[Pair(agentId, thread.id)] = maxIndex + 1 + } + } + } +} diff --git a/src/main/kotlin/org/coralprotocol/coralserver/session/CountBasedScheduler.kt b/src/main/kotlin/org/coralprotocol/coralserver/session/CountBasedScheduler.kt new file mode 100644 index 0000000000000000000000000000000000000000..264ee59e583c6670385432bce7ce59a847a01730 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/session/CountBasedScheduler.kt @@ -0,0 +1,78 @@ +package org.coralprotocol.coralserver.session + +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.withTimeoutOrNull +import java.util.concurrent.ConcurrentHashMap +import kotlin.concurrent.atomics.AtomicInt +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.concurrent.atomics.incrementAndFetch + +class CountBasedScheduler() { + private val agentCountNotifications = ConcurrentHashMap>>() + @OptIn(ExperimentalAtomicApi::class) + private var registeredAgentsCount = AtomicInt(0) + + @OptIn(ExperimentalAtomicApi::class) + fun getRegisteredAgentsCount(): Int = registeredAgentsCount.load() + + @OptIn(ExperimentalAtomicApi::class) + fun clear() { + agentCountNotifications.clear() + registeredAgentsCount.store(0) + } + + @OptIn(ExperimentalAtomicApi::class) + fun markAgentReady(agentId: String) { + registeredAgentsCount.incrementAndFetch() + + // Create a copy of the keys to avoid ConcurrentModificationException + val targetCounts = agentCountNotifications.keys.toList() + + // For each target count that has been reached + for (targetCount in targetCounts) { + if (registeredAgentsCount.load() >= targetCount) { + // Get the list of deferreds for this target count + val deferredList = agentCountNotifications[targetCount] + if (deferredList != null) { + // Complete all deferreds that are not already completed + for (deferred in deferredList) { + if (!deferred.isCompleted) { + deferred.complete(true) + } + } + // Remove this target count from the map + agentCountNotifications.remove(targetCount) + } + } + } + } + + + @OptIn(ExperimentalAtomicApi::class) + suspend fun waitForAgentCount(targetCount: Int, timeoutMs: Long): Boolean { + if (registeredAgentsCount.load() >= targetCount) return true + + val deferred = CompletableDeferred() + + val deferredList = agentCountNotifications.computeIfAbsent(targetCount) { mutableListOf() } + deferredList.add(deferred) + + val result = withTimeoutOrNull(timeoutMs) { + deferred.await() + } ?: false + + if (!result) { + // If the wait timed out, remove this deferred from the list + val deferredsList = agentCountNotifications[targetCount] + if (deferredsList != null) { + deferredsList.remove(deferred) + // If the list is now empty, remove the target count from the map + if (deferredsList.isEmpty()) { + agentCountNotifications.remove(targetCount) + } + } + } + + return result + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/session/GroupScheduler.kt b/src/main/kotlin/org/coralprotocol/coralserver/session/GroupScheduler.kt new file mode 100644 index 0000000000000000000000000000000000000000..95494679bddd54a05c2af0dc7e62a46e56dfbb35 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/session/GroupScheduler.kt @@ -0,0 +1,77 @@ +package org.coralprotocol.coralserver.session + +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.withTimeoutOrNull +import java.util.concurrent.ConcurrentHashMap +import kotlin.concurrent.atomics.AtomicInt +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.concurrent.atomics.incrementAndFetch + +data class GroupScheduler( + val groups: List> = listOf(), +) { + @OptIn(ExperimentalAtomicApi::class) + private val agentGroupRegisteredCount = List(groups.size) { i -> Pair(i, AtomicInt(0)) }.toMap() + private val agentGroupRequiredCount = groups.mapIndexed { i, grp -> Pair(i, grp.size) }.toMap() + private val agentGroupMembership = groups.flatMapIndexed { i, grp -> grp.map { agent -> Pair(agent, i) } }.toMap() + private val agentGroupNotifications = ConcurrentHashMap>>() + + + @OptIn(ExperimentalAtomicApi::class) + fun clear() { + agentGroupNotifications.clear() + agentGroupRegisteredCount.values.forEach{it -> it.store(0)} + } + + @OptIn(ExperimentalAtomicApi::class) + fun markAgentReady(agentId: String) { + // Increment the counter of agents registered for the agent's blocking group + agentGroupMembership[agentId]?.let { group -> + agentGroupRegisteredCount[group]?.let { + val count = it.incrementAndFetch() + val required = agentGroupRequiredCount[group] + ?: throw AssertionError("Group $group has registered counter, but no required count") + + if (count >= required) { + agentGroupNotifications[group]?.forEach { listener -> + if (!listener.isCompleted) listener.complete( + true + ) + } + } + } + } + } + + @OptIn(ExperimentalAtomicApi::class) + suspend fun waitForGroup(agentId: String, timeoutMs: Long): Boolean { + // if there's no group for this agent, the agent must be non-blocking + val group = agentGroupMembership[agentId] ?: return true + val required = agentGroupRequiredCount[group] + ?: throw AssertionError("Group $group implied through membership, but no required count exists") + + if ((agentGroupRegisteredCount[group]?.load() ?: 0) >= required) return true + + val deferred = CompletableDeferred() + val deferredList = agentGroupNotifications.computeIfAbsent(group) { mutableListOf() } + + deferredList.add(deferred) + + val result = withTimeoutOrNull(timeoutMs) { + deferred.await() + } ?: false + + if (!result) { + // If the wait timed out, remove this deferred from the list + agentGroupNotifications[group]?.let { + it.remove(deferred) + // If the list is now empty, remove the target count from the map + if (it.isEmpty()) { + agentGroupNotifications.remove(group) + } + } + } + + return result + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/session/SessionEvent.kt b/src/main/kotlin/org/coralprotocol/coralserver/session/SessionEvent.kt new file mode 100644 index 0000000000000000000000000000000000000000..cfc01d872cd9b53a4d35ad29560f76e4299d519b --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/session/SessionEvent.kt @@ -0,0 +1,35 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.session + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator +import org.coralprotocol.coralserver.models.Agent +import org.coralprotocol.coralserver.models.AgentState +import org.coralprotocol.coralserver.models.ResolvedMessage + +@Serializable +@JsonClassDiscriminator("type") +sealed interface SessionEvent { + @Serializable + @SerialName("agent_registered") + data class AgentRegistered(val agent: Agent) : SessionEvent + + @Serializable + @SerialName("agent_state_updated") + data class AgentStateUpdated(val agentId: String, val state: AgentState): SessionEvent + + @Serializable + @SerialName("agent_ready") + data class AgentReady(val agent: String): SessionEvent + + @Serializable + @SerialName("thread_created") + data class ThreadCreated(val id: String, val name: String, val creatorId: String, val participants: List, val summary: String?): SessionEvent + + @Serializable + @SerialName("message_sent") + data class MessageSent(val threadId: String, val message: ResolvedMessage): SessionEvent +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/session/SessionManager.kt b/src/main/kotlin/org/coralprotocol/coralserver/session/SessionManager.kt new file mode 100644 index 0000000000000000000000000000000000000000..e8725b47d336a2ac1ab985bd4c319089d9a92b24 --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/session/SessionManager.kt @@ -0,0 +1,156 @@ +package org.coralprotocol.coralserver.session + +import com.chrynan.uri.core.Uri +import com.chrynan.uri.core.fromParts +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.sync.Semaphore +import kotlinx.coroutines.sync.withPermit +import kotlinx.coroutines.withTimeoutOrNull +import org.coralprotocol.coralserver.agent.graph.AgentGraph +import org.coralprotocol.coralserver.agent.runtime.Orchestrator +import java.util.concurrent.ConcurrentHashMap + +fun AgentGraph.adjacencyMap(): Map> { + val map = mutableMapOf>() + + // each set in the set of links defines one strongly connected component (scc), + // where each member of the scc is bidirectionally connected to every other member of the scc + links.forEach { scc -> + for (a in scc) { + for (b in scc) { + if (a == b) continue + map.getOrPut(a) { mutableSetOf() }.add(b) + map.getOrPut(b) { mutableSetOf() }.add(a) + } + } + } + return map +} + +/** + * Session manager to create and retrieve sessions. + */ +class SessionManager(val orchestrator: Orchestrator = Orchestrator(), val port: UShort) { + private val sessions = ConcurrentHashMap() + private val sessionSemaphore = Semaphore(1) + + private val sessionListeners = ConcurrentHashMap>>() + + suspend fun waitForSession(id: String, timeoutMs: Long = 10000): CoralAgentGraphSession? { + if (sessions.containsKey(id)) return sessions[id] + val deferred = CompletableDeferred() + sessionListeners.computeIfAbsent(id) { mutableListOf() }.add(deferred) + + val result = withTimeoutOrNull(timeoutMs) { + deferred.await() + } ?: false + + if (!result) { + // If the wait timed out, remove this deferred from the list + sessionListeners[id]?.let { + it.remove(deferred) + // If the list is now empty, remove the target count from the map + if (it.isEmpty()) { + sessionListeners.remove(id) + } + } + } + + return sessions[id] + } + + /** + * Create a new session with a random ID. + */ + fun createSession(applicationId: String, privacyKey: String, agentGraph: AgentGraph? = null): CoralAgentGraphSession = + createSessionWithId(java.util.UUID.randomUUID().toString(), applicationId, privacyKey, agentGraph) + + /** + * Create a new session with a specific ID. + */ + fun createSessionWithId( + sessionId: String, + applicationId: String, + privacyKey: String, + agentGraph: AgentGraph? = null + ): CoralAgentGraphSession { + val subgraphs = agentGraph?.let { it -> + + val adj = it.adjacencyMap() + val visited = mutableSetOf() + val subgraphs = mutableListOf>() + + // flood fill to find all disconnected subgraphs + for (node in adj.keys) { + if (visited.contains(node)) continue + // non-blocking agents should not be considered part of any subgraph + if (it.agents[node]?.blocking == false) continue + + val subgraph = mutableSetOf(node) + val toVisit = adj[node]?.toMutableList() + while (toVisit?.isNotEmpty() == true) { + val next = toVisit.removeLast() + if (visited.contains(next)) continue + // non-blocking agents should not be considered part of any subgraph + if (it.agents[next]?.blocking == false) continue + subgraph.add(next) + visited.add(next) + adj[next]?.let { n -> toVisit.addAll(n) } + } + subgraphs.add(subgraph) + visited.add(node) + } + + it.agents.forEach { agent -> + orchestrator.spawn( + sessionId = sessionId, + graphAgent = agent.value, + port = port, + agentName = agent.key.toString(), + relativeMcpServerUri = Uri.fromParts(scheme = "http", path = "/sse/v1/${applicationId}/${privacyKey}/${sessionId}/sse", query = "agentId=${agent.key}"), + sessionManager = this, + ) + } + subgraphs + } + val session = CoralAgentGraphSession(sessionId, applicationId, privacyKey, agentGraph = agentGraph, groups = subgraphs?.toList() ?: emptyList()) + sessions[sessionId] = session + sessionListeners[sessionId]?.let { it -> + it.forEach { + if (!it.isCompleted) { + it.complete(true) + } + } + } + return session + } + + /** + * Get or create a session with a specific ID. + * If the session exists, return it. Otherwise, create a new one. + */ + suspend fun getOrCreateSession( + sessionId: String, + applicationId: String, + privacyKey: String, + agentGraph: AgentGraph? = null + ): CoralAgentGraphSession { + sessionSemaphore.withPermit { + return sessions[sessionId] ?: createSessionWithId(sessionId, applicationId, privacyKey, agentGraph) + } + } + + /** + * Get a session by ID. + */ + fun getSession(sessionId: String): CoralAgentGraphSession? { + return sessions[sessionId] + } + + /** + * Get all sessions. + */ + fun getAllSessions(): List { + return sessions.values.toList() + } +} \ No newline at end of file diff --git a/src/main/kotlin/org/coralprotocol/coralserver/session/SessionModels.kt b/src/main/kotlin/org/coralprotocol/coralserver/session/SessionModels.kt new file mode 100644 index 0000000000000000000000000000000000000000..6342c961f0da9d192e33b6ce860a5e3cc835bebd --- /dev/null +++ b/src/main/kotlin/org/coralprotocol/coralserver/session/SessionModels.kt @@ -0,0 +1,156 @@ +@file:OptIn(ExperimentalSerializationApi::class) + +package org.coralprotocol.coralserver.session + +import com.chrynan.uri.core.UriString +import io.github.oshai.kotlinlogging.KotlinLogging +import io.github.smiley4.schemakenerator.core.annotations.Description +import io.ktor.client.* +import io.ktor.client.engine.cio.* +import io.ktor.client.plugins.contentnegotiation.* +import io.ktor.client.request.* +import io.ktor.client.statement.* +import io.ktor.http.* +import io.ktor.serialization.kotlinx.json.* +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.buildClassSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder +import kotlinx.serialization.json.* +import net.pwall.json.schema.JSONSchema +import org.coralprotocol.coralserver.agent.graph.GraphAgentRequest +import org.coralprotocol.coralserver.agent.runtime.RuntimeId +import org.coralprotocol.coralserver.server.CoralAgentIndividualMcp + +private val logger = KotlinLogging.logger {} + +/** + * Data class for session creation request. + */ +@Serializable +data class CreateSessionRequest( + val applicationId: String, + val sessionId: String? = null, + val privacyKey: String, + val agentGraph: AgentGraphRequest?, +) + +@Serializable +data class AgentGraphRequest( + val agents: HashMap, + val links: Set>, + val tools: Map = emptyMap(), +) + +object JSONSchemaSerializer : KSerializer { + // Serial names of descriptors should be unique, so choose app-specific name in case some library also would declare a serializer for Date. + override val descriptor: SerialDescriptor = buildClassSerialDescriptor("org.coralprotocol.JSONSchemaWithRaw") { + + } + + override fun serialize(encoder: Encoder, value: JSONSchemaWithRaw) { + val json = encoder as? JsonEncoder ?: throw SerializationException("Can be serialized only as JSON") + return json.encodeJsonElement(value.raw) + } + + override fun deserialize(decoder: Decoder): JSONSchemaWithRaw { + val jsonInput = decoder as? JsonDecoder ?: error("Can be deserialized only by JSON") + val obj = jsonInput.decodeJsonElement().jsonObject; + + return JSONSchemaWithRaw(schema = JSONSchema.parse(obj.toString()), raw = obj) + } +} + +@Serializable(with = JSONSchemaSerializer::class) +data class JSONSchemaWithRaw( + val schema: JSONSchema, + val raw: JsonObject, +) + +@Serializable +data class CustomTool( + val transport: ToolTransport, + val toolSchema: Tool, +) + +fun CoralAgentIndividualMcp.addExtraTool(sessionId: String, agentId: String, tool: CustomTool) { + addTool( + name = tool.toolSchema.name, + description = tool.toolSchema.description ?: "", + inputSchema = tool.toolSchema.inputSchema, + ) { request -> + tool.transport.handleRequest(sessionId, agentId, request, tool.toolSchema) + } +} + + +@Serializable +@JsonClassDiscriminator("type") +sealed interface ToolTransport { + @SerialName("http") + @Serializable + data class Http(val url: UriString) : ToolTransport { + override suspend fun handleRequest( + sessionId: String, + agentId: String, + request: CallToolRequest, + toolSchema: Tool + ): CallToolResult { + try { + val client = HttpClient(CIO) { + install(ContentNegotiation) { + json() + } + engine { + requestTimeout = 0 + } + } + + val response = client.post(url.value) { + url { + appendPathSegments(sessionId, agentId) + } + contentType(ContentType.Application.Json) + setBody(request.arguments) + } + + val body = response.bodyAsText() + return CallToolResult( + content = listOf(TextContent(body)) + ) + } catch (ex: Exception) { + logger.error(ex) { "Error occurred while executing request" } + return CallToolResult( + isError = true, + content = listOf(TextContent("Error: $ex")) + ) + } + } + } + + suspend fun handleRequest( + sessionId: String, + agentId: String, + request: CallToolRequest, + toolSchema: Tool + ): CallToolResult +} + +/** + * Data class for session creation response. + */ +@Serializable +data class CreateSessionResponse( + val sessionId: String, + val applicationId: String, + val privacyKey: String +) \ No newline at end of file diff --git a/src/main/resources/application.yaml b/src/main/resources/application.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e29ec336b19fedb7f73f0c96a3566741463b007 --- /dev/null +++ b/src/main/resources/application.yaml @@ -0,0 +1,16 @@ +# Default application configuration +# TODO: Applications are a work in progress. This is safe to ignore for now. +applications: + - id: "app" + name: "Default Application" + description: "Default application for testing" + privacyKeys: + - "default-key" + - "public" + - "priv" + +# Uncomment to configure an external application source +# applicationSource: +# type: "http" +# url: "https://example.com/applications" +# refreshIntervalSeconds: 3600 \ No newline at end of file diff --git a/src/main/resources/registry.toml b/src/main/resources/registry.toml new file mode 100644 index 0000000000000000000000000000000000000000..c4bfefd2742383c29c6b8978cd8590075f922dfa --- /dev/null +++ b/src/main/resources/registry.toml @@ -0,0 +1,30 @@ +# +# The agent import section defines a list of agents that are available to the Coral Server that uses them; agents can be +# provided either by: +# - A path to an agent folder that contains a coral-agent.toml file +# - A Git repo (identified by tag, branch or revision) that contains a coral-agent.toml file at the root level +# - A marketplace name +[agent-import] + +# +# One +interface = { path = "examples/camel-search-maths/interface" } + +math = { path = "examples/camel-search-maths/math" } +search = { path = "examples/camel-search-maths/search" } +#git-example1 = { git = "https://github.com/Coral-Protocol/coral-cli.git", rev = "ffffff" } +#git-example2 = { git = "https://github.com/Coral-Protocol/coral-cli.git", branch = "main" } +#git-example3 = { git = "https://github.com/Coral-Protocol/coral-cli.git", tag = "v1" } + + +[agent-export] +# name must match an imported agent +# for hackathon this is not publicized + +[agent-export.interface] +quantity = 10 + +# Actual runtime pricing format is not yet determined +[agent-export.interface.runtimes] +executable = { min_price = 512.0, max_price = 1024.0 } +docker = { min_price = 1024.0, max_price = 4096.0 } \ No newline at end of file diff --git a/src/test/kotlin/org/coralprotocol/coralserver/e2e/E2EResourceTest.kt b/src/test/kotlin/org/coralprotocol/coralserver/e2e/E2EResourceTest.kt new file mode 100644 index 0000000000000000000000000000000000000000..66ce8019e2f781472e882ad43858c0467f26a61d --- /dev/null +++ b/src/test/kotlin/org/coralprotocol/coralserver/e2e/E2EResourceTest.kt @@ -0,0 +1,53 @@ +package org.coralprotocol.coralserver.e2e + +import kotlinx.coroutines.* +import org.eclipse.lmos.arc.agents.agent.ask +import org.eclipse.lmos.arc.core.Success +import org.junit.jupiter.api.BeforeEach +import kotlin.test.Test + + +class E2EResourceTest { + val port: UShort = 14391u + var server = TestCoralServer(port = port, devmode = true) + + @OptIn(DelicateCoroutinesApi::class) + @BeforeEach + fun setup() { + server.setup() + } + + @Test + fun testCreateThreadAndPostMessage(): Unit = runBlocking { + var allAssertsCompleted = false + createSessionWithConnectedAgents(server.server!!, sessionId = "test", privacyKey = "aaa", applicationId = "aaa", noAgentsOptional = true) { + val agent1 = agent("testAgent1", "testAgent1") + val agent2 = agent("testAgent2", "testAgent2") + + onAgentsCreated = { + agent1.getConnected().ask("Say hello to testAgent2 in a new thread. Tell it the passcode 3243") + val sessions = server.sessionManager.getAllSessions() + assert(sessions.size == 1) { "There should be one session" } + val session = sessions.first() + val threads = session.getAllThreads() + assert(threads.size == 1) { "There should be one thread" } + val thread = threads.first() + val messages = thread.messages + assert(messages.size == 1) { "There should be one message" } + + // Verify agent2 can receive the message + val agent2Response = agent2.getConnected().ask("What is the passcode testAgent1 just told you? use wait for mentions to check") as Success + assert(agent2Response.value.contains("3243")) { "Agent2 should receive the code from agent1" } + + // Verify agent2 can send back a message in the same thread + val agent2Response2 = agent2.getConnected().ask("The passcode is 9920. Pass it to testAgent1 in the same thread") as Success + val agent1PasscodeFrom2Resp = agent1.getConnected().ask("What is the passcode? testAgent2 just told you? use wait for mentions to check") as Success + + assert(agent1PasscodeFrom2Resp.value.contains("9920")) { "Agent1 should receive the right code returning from agent 2" } + assert(session.getAllThreads().size == 1) { "There should still be one thread" } + allAssertsCompleted = true + } + } + assert(allAssertsCompleted) { "All asserts completed." } + } +} diff --git a/src/test/kotlin/org/coralprotocol/coralserver/e2e/TestAgentUtils.kt b/src/test/kotlin/org/coralprotocol/coralserver/e2e/TestAgentUtils.kt new file mode 100644 index 0000000000000000000000000000000000000000..21e73ca52dc4148fb9436d4893dc1c071ba1f722 --- /dev/null +++ b/src/test/kotlin/org/coralprotocol/coralserver/e2e/TestAgentUtils.kt @@ -0,0 +1,290 @@ +package org.coralprotocol.coralserver.e2e + +import com.azure.ai.openai.OpenAIClientBuilder +import com.azure.core.credential.KeyCredential +import io.mockk.every +import io.mockk.mockkStatic +import io.modelcontextprotocol.util.Utils.resolveUri +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.DelicateCoroutinesApi +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch +import kotlinx.coroutines.newFixedThreadPoolContext +import kotlinx.coroutines.runBlocking +import org.coralprotocol.coralserver.config.ConfigCollection +import org.coralprotocol.coralserver.server.CoralServer +import org.coralprotocol.coralserver.session.CoralAgentGraphSession +import org.coralprotocol.coralserver.session.SessionManager +import org.eclipse.lmos.arc.agents.ChatAgent +import org.eclipse.lmos.arc.agents.agent.ask +import org.eclipse.lmos.arc.agents.agents +import org.eclipse.lmos.arc.agents.dsl.AllTools +import org.eclipse.lmos.arc.agents.llm.AIClientConfig +import org.eclipse.lmos.arc.agents.llm.MapChatCompleterProvider +import org.eclipse.lmos.arc.client.azure.AzureAIClient +import org.eclipse.lmos.arc.mcp.McpTools +import java.net.URI +import java.net.http.HttpRequest +import kotlin.time.Duration.Companion.seconds +import kotlin.time.toJavaDuration + +suspend fun createSessionWithConnectedAgents( + server: CoralServer, + sessionId: String, + privacyKey: String, + applicationId: String, + noAgentsOptional: Boolean = true, + agentsInSessionBlock: SessionCoralAgentDefinitionContext.() -> Unit, +): CoralAgentGraphSession { + val session = server.sessionManager.getOrCreateSession( + sessionId = sessionId, + applicationId = applicationId, + privacyKey = privacyKey, + agentGraph = null, + ) + + val context = BasicSessionCoralAgentDefinitionContext(server) + agentsInSessionBlock(context) + if (noAgentsOptional) { + session.devRequiredAgentStartCount = context.buildChatAgents(session).size + } + context.buildChatAgents(session) + context.onAgentsCreated(SessionCoralAgentDefinitionContext.AgentsCreatedContext()) + return session +} + +interface SessionCoralAgentDefinitionContext { + val server: CoralServer + fun agent( + name: String, + description: String = name, + systemPrompt: String = defaultSystemPrompt, + modelName: String = "gpt-4o", + agentClient: AzureAIClient = createTestAIClient(), + ): Deferred + + class AgentsCreatedContext { + // TODO: Consider using a more specific type than Deferred + // TODO: Consider using more guaranteed invokeCompleted + @OptIn(ExperimentalCoroutinesApi::class) + suspend fun Deferred.getConnected() = this.getCompleted() + } + + var onAgentsCreated: suspend AgentsCreatedContext.() -> Unit +} + +private class BasicSessionCoralAgentDefinitionContext(override val server: CoralServer) : + SessionCoralAgentDefinitionContext { + override var onAgentsCreated: suspend SessionCoralAgentDefinitionContext.AgentsCreatedContext.() -> Unit = { } + private val agentsToAdd = mutableListOf ChatAgent>() + override fun agent( + name: String, + description: String, + systemPrompt: String, + modelName: String, + agentClient: AzureAIClient, + ): Deferred { + val deferrable = CompletableDeferred() + agentsToAdd.add { session -> + val createConnectedCoralAgent = createConnectedCoralAgent( + server = server, + namePassedToServer = name, + descriptionPassedToServer = description, + systemPrompt = systemPrompt, + agentClient = agentClient, + modelName = modelName, + sessionId = session.id, + privacyKey = session.privacyKey, + applicationId = session.applicationId + ) + deferrable.complete(createConnectedCoralAgent) + return@add createConnectedCoralAgent + } + return deferrable + } + + val context = newFixedThreadPoolContext(5, "E2EResourceTest") + suspend fun buildChatAgents(session: CoralAgentGraphSession): List = coroutineScope { + return@coroutineScope agentsToAdd.map { + async(context) { it(session) } + }.awaitAll() + } +} + + +fun createConnectedCoralAgent( + server: CoralServer, + namePassedToServer: String, + descriptionPassedToServer: String = namePassedToServer, + systemPrompt: String = "You are a helpful assistant.", + agentClient: AzureAIClient = createTestAIClient(), + modelName: String = "gpt-4o", + sessionId: String = "session1", + privacyKey: String = "privkey", + applicationId: String = "exampleApplication", +): ChatAgent = createConnectedCoralAgent( + "http", + server.host, + server.port, + namePassedToServer, + descriptionPassedToServer, + systemPrompt, + agentClient, + modelName, + sessionId, + privacyKey, + applicationId +) + +/** + * Creates a connected Coral agent. + * + * @param port The port to connect to. + * @param namePassedToServer The name of the agent passed to the server. + * @param descriptionPassedToServer The description of the agent passed to the server. + * @param systemPrompt The system prompt for the agent. + * @param agentClient The AzureAIClient to use. + * @param sessionId The session ID for the agent. + * @param privacyKey The privacy key for the agent. + * @param applicationId The application ID for the agent. + * @return The created agent. + */ + +fun createConnectedCoralAgent( + protocol: String = "http", + host: String = "localhost", + port: UShort, + namePassedToServer: String, + descriptionPassedToServer: String = namePassedToServer, + systemPrompt: String = "You are a helpful assistant.", + agentClient: AzureAIClient = createTestAIClient(), + modelName: String = "gpt-4o", + sessionId: String = "session1", + privacyKey: String = "privkey", + applicationId: String = "exampleApplication", +): ChatAgent { + val agent = agents( + functionLoaders = listOf( + McpTools( + "$protocol://$host:$port/devmode/$applicationId/$privacyKey/$sessionId/sse?agentId=$namePassedToServer", + 5000.seconds.toJavaDuration() + ) + ), + chatCompleterProvider = MapChatCompleterProvider(mapOf(modelName to agentClient)), + ) { + agent { + this@agent.name = namePassedToServer + this@agent.description = descriptionPassedToServer + + model { + modelName + } + + prompt { systemPrompt } + tools = AllTools + } + }.getAgents().first() as ChatAgent + runBlocking { + agent.ask("hi") // TODO: This is a hack to make sure the agent is connected. + //TODO: Make arc connect to MCP servers eagerly. + } + return agent +} + +/** + * Creates an AzureAIClient for testing. + * + * @return An AzureAIClient configured for testing. + */ +fun createTestAIClient(): AzureAIClient { + val config = AIClientConfig( + modelName = "gpt-4o", + apiKey = System.getenv("OPENAI_API_KEY"), + endpoint = "https://api.openai.com/v1", + client = "?" + ) + val azureOpenAIClient = OpenAIClientBuilder() + .endpoint(config.endpoint) + .credential(KeyCredential(config.apiKey)) + .buildAsyncClient() + + return AzureAIClient(config, azureOpenAIClient) +} + +class TestCoralServer( + val host: String = "0.0.0.0", + val port: UShort = 5555u, + val devmode: Boolean = false, + val sessionManager: SessionManager = SessionManager(port = port), +) { + var server: CoralServer? = null + + @OptIn(DelicateCoroutinesApi::class) + val serverContext = newFixedThreadPoolContext(1, "E2EResourceTest") + + @OptIn(DelicateCoroutinesApi::class) + fun setup() { + server?.stop() + server = CoralServer( + host = host, + port = port, + devmode = devmode, + sessionManager = sessionManager, + appConfig = ConfigCollection(null) + ) + GlobalScope.launch(serverContext) { + server!!.start() + } + patchMcpJavaContentType() + patchMcpJavaEndpointResolution() + } +} + + +private fun patchMcpJavaContentType() { + mockkStatic(HttpRequest::class) + every { HttpRequest.newBuilder() } answers { + println("MockK Interceptor [@BeforeEach]: HttpRequest.newBuilder() called. ") + val requestBuilder = callOriginal().headers("Content-Type", "application/json").timeout(40.seconds.toJavaDuration()) + return@answers requestBuilder + } +} + +private fun patchMcpJavaEndpointResolution() { + mockkStatic(io.modelcontextprotocol.util.Utils::class) + every { resolveUri(any(), any()) } answers { + val baseUrl = invocation.args[0] as URI + val endpointUrl = invocation.args[1] as String + println("MockK Interceptor [@BeforeEach]: Utils.resolveUri called with baseUrl='$baseUrl', endpointUrl='$endpointUrl'. ") + return@answers if (endpointUrl.contains("?sessionId")) { + // In this case the sessionId is an MCP sessionId, not a Coral sessionId. + // The resolution logic works in this case (though the original is resolving against a URI object) + baseUrl.resolve(endpointUrl) + } else { + baseUrl + } + } +} + +private val defaultSystemPrompt = """ +You have access to communication tools to interact with other agents. + +If there are no other agents, remember to re-list the agents periodically using the list tool. + +You should know that the user can't see any messages you send, you are expected to be autonomous and respond to the user only when you have finished working with other agents, using tools specifically for that. + +You can emit as many messages as you like before using that tool when you are finished or absolutely need user input. You are on a loop and will see a "user" message every 4 seconds, but it's not really from the user. + +Run the wait for mention tool when you are ready to receive a message from another agent. This is the preferred way to wait for messages from other agents. + +You'll only see messages from other agents since you last called the wait for mention tool. Remember to call this periodically. Also call this when you're waiting with nothing to do. + +Don't try to guess any numbers or facts, only use reliable sources. If you are unsure, ask other agents for help. + +If you have been given a simple task by the user, you can use the wait for mention tool once with a short timeout and then return the result to the user in a timely fashion. +""".trimIndent() \ No newline at end of file diff --git a/src/test/kotlin/org/coralprotocol/coralserver/models/Agent.kt b/src/test/kotlin/org/coralprotocol/coralserver/models/Agent.kt new file mode 100644 index 0000000000000000000000000000000000000000..0089aa532297824d56790ccf05edbf9ac2b54f2c --- /dev/null +++ b/src/test/kotlin/org/coralprotocol/coralserver/models/Agent.kt @@ -0,0 +1,57 @@ +import com.akuleshov7.ktoml.Toml +import io.ktor.server.html.insert +import kotlinx.serialization.KSerializer +import kotlinx.serialization.Serializable +import kotlinx.serialization.decodeFromString +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.buildClassSerialDescriptor +import kotlinx.serialization.descriptors.serialDescriptor +import kotlinx.serialization.encodeToString +import kotlinx.serialization.encoding.CompositeDecoder +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder +import kotlinx.serialization.encoding.encodeCollection + +@Serializable +enum class Enum { + @Serializable + A, + @Serializable + B, + @Serializable + C +} + +@Serializable(with = EnumsSerializer::class) +data class Enums( + val enums: List +) + +object EnumsSerializer: KSerializer { + override val descriptor: SerialDescriptor + get() = buildClassSerialDescriptor("Enums", + serialDescriptor>()) + + override fun serialize(encoder: Encoder, value: Enums) { + //encoder.encodeCollection(descriptor, value.enums.size) {} + TODO("Not yet implemented") + } + + override fun deserialize(decoder: Decoder): Enums { + val enums = ArrayList() + val struct = decoder.beginStructure(descriptor) + while (struct.decodeElementIndex(descriptor) != CompositeDecoder.DECODE_DONE) { + enums.add(Enum.valueOf(struct.decodeStringElement(descriptor, 0))) + } + TODO("Not yet implemented") + } +} + +fun main() { + // val enums = Enums(listOf(Enum.A, Enum.B, Enum.C)) + + val toml = Toml() + val encoded = "enums = [\"A\"]"// toml.encodeToString(enums) + println(encoded) + println(toml.decodeFromString(encoded)) +} \ No newline at end of file diff --git a/src/test/kotlin/org/coralprotocol/coralserver/models/Graph.kt b/src/test/kotlin/org/coralprotocol/coralserver/models/Graph.kt new file mode 100644 index 0000000000000000000000000000000000000000..ede70f25080e45f099a121398a82820cd3c6ebc4 --- /dev/null +++ b/src/test/kotlin/org/coralprotocol/coralserver/models/Graph.kt @@ -0,0 +1,70 @@ +package org.coralprotocol.coralserver.models + +import kotlinx.serialization.json.Json +import org.coralprotocol.coralserver.agent.graph.GraphAgentProvider +import org.coralprotocol.coralserver.agent.graph.GraphAgentRequest +import org.coralprotocol.coralserver.agent.graph.GraphAgentServer +import org.coralprotocol.coralserver.agent.graph.GraphAgentServerAttribute +import org.coralprotocol.coralserver.agent.graph.GraphAgentServerAttributeType +import org.coralprotocol.coralserver.agent.graph.GraphAgentServerCustomScorer +import org.coralprotocol.coralserver.agent.graph.GraphAgentServerScorerEffect +import org.coralprotocol.coralserver.agent.graph.GraphAgentServerScoring +import org.coralprotocol.coralserver.agent.graph.GraphAgentServerSource +import org.coralprotocol.coralserver.agent.runtime.RuntimeId + +fun local(json: Json): String = + json.encodeToString(GraphAgentRequest( + agentName = "interface", + options = mapOf(), + blocking = false, + tools = setOf(), + provider = GraphAgentProvider.Local(RuntimeId.DOCKER) + )) + +fun remote(json: Json): String = + json.encodeToString(GraphAgentRequest( + agentName = "interface", + options = mapOf(), + blocking = false, + tools = setOf(), + provider = GraphAgentProvider.Remote( + runtime = RuntimeId.DOCKER, + serverSource = GraphAgentServerSource.Servers( + listOf( + GraphAgentServer("https://hackathon.coralprotocol.org:5555", + listOf( + GraphAgentServerAttribute.String(GraphAgentServerAttributeType.ATTESTED_BY, "Coral Team"), + )), + GraphAgentServer("https://coral.mycompany.com:5555", + listOf( + GraphAgentServerAttribute.String(GraphAgentServerAttributeType.ATTESTED_BY, "Myself"), + )) + ) + ), + serverScoring = GraphAgentServerScoring.Custom( + scorers = listOf( + GraphAgentServerCustomScorer.StringEqual( + type = GraphAgentServerAttributeType.ATTESTED_BY, + string = "Coral Team", + effect = GraphAgentServerScorerEffect.Flat(10.0) + ), + GraphAgentServerCustomScorer.StringEqual( + type = GraphAgentServerAttributeType.ATTESTED_BY, + string = "Myself", + effect = GraphAgentServerScorerEffect.Flat(20.0) + ) + ) + ) + ) + )) + +fun main() { + val json = Json { + encodeDefaults = true + prettyPrint = true + explicitNulls = false + } + + println(local(json)) + println(remote(json)) +} \ No newline at end of file diff --git a/src/test/kotlin/org/coralprotocol/coralserver/session/SessionManagerTest.kt b/src/test/kotlin/org/coralprotocol/coralserver/session/SessionManagerTest.kt new file mode 100644 index 0000000000000000000000000000000000000000..e26f17a82cfbd368dc92b3969ca8e5c1f610572b --- /dev/null +++ b/src/test/kotlin/org/coralprotocol/coralserver/session/SessionManagerTest.kt @@ -0,0 +1,202 @@ +package org.coralprotocol.coralserver.session + +import kotlinx.coroutines.runBlocking +import org.coralprotocol.coralserver.models.Agent +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import java.util.concurrent.ConcurrentHashMap + +class SessionManagerTest { + + var sessionManager = SessionManager(port = 0u) + @BeforeEach + fun setup() { + sessionManager = SessionManager(port = 0u) + } + + @Test + fun `test create session with random ID`() { + // Create a session with a random ID + val session = sessionManager.createSession("app1", "key1") + + // Verify session was created with correct properties + assertNotNull(session) + assertEquals("app1", session.applicationId) + assertEquals("key1", session.privacyKey) + assertTrue(session.id.isNotEmpty()) + + // Verify session can be retrieved + val retrievedSession = sessionManager.getSession(session.id) + assertEquals(session, retrievedSession) + } + + @Test + fun `test create session with specific ID`() { + // Create a session with a specific ID + val session = sessionManager.createSessionWithId("session1", "app1", "key1") + + // Verify session was created with correct properties + assertNotNull(session) + assertEquals("session1", session.id) + assertEquals("app1", session.applicationId) + assertEquals("key1", session.privacyKey) + + // Verify session can be retrieved + val retrievedSession = sessionManager.getSession("session1") + assertEquals(session, retrievedSession) + } + + @Test + fun `test get or create session - existing session`() = runBlocking { + // Create a session first + sessionManager.createSessionWithId("session2", "app1", "key1") + + // Get the existing session + val session = sessionManager.getOrCreateSession("session2", "app2", "key2") + + // Verify the existing session is returned (not a new one with updated properties) + assertEquals("session2", session.id) + assertEquals("app1", session.applicationId) // Should still be app1, not app2 + assertEquals("key1", session.privacyKey) // Should still be key1, not key2 + } + + @Test + fun `test get or create session - new session`() = runBlocking { + // Get or create a new session + val session = sessionManager.getOrCreateSession("session3", "app3", "key3") + + // Verify a new session was created + assertEquals("session3", session.id) + assertEquals("app3", session.applicationId) + assertEquals("key3", session.privacyKey) + + // Verify session can be retrieved + val retrievedSession = sessionManager.getSession("session3") + assertEquals(session, retrievedSession) + } + + @Test + fun `test get session - non-existent session`() { + // Try to get a non-existent session + val session = sessionManager.getSession("nonexistent") + + // Verify null is returned + assertNull(session) + } + + @Test + fun `test get all sessions`() { + // Create multiple sessions + val session1 = sessionManager.createSessionWithId("session1", "app1", "key1") + val session2 = sessionManager.createSessionWithId("session2", "app2", "key2") + val session3 = sessionManager.createSessionWithId("session3", "app3", "key3") + + // Get all sessions + val sessions = sessionManager.getAllSessions() + + // Verify all sessions are returned + assertEquals(3, sessions.size) + assertTrue(sessions.contains(session1)) + assertTrue(sessions.contains(session2)) + assertTrue(sessions.contains(session3)) + } + + @Test + fun `test threads are not available across sessions with different privacy keys`() { + // Create two sessions with different privacy keys + val session1 = sessionManager.createSessionWithId("session1", "app1", "key1") + val session2 = sessionManager.createSessionWithId("session2", "app1", "key2") + + // Register agents in both sessions + val creator1 = session1.registerAgent(agentId = "creator1") ?: throw AssertionError("could not register agent") + val participant1 = session1.registerAgent(agentId = "participant1") ?: throw AssertionError("could not register agent") + + val creator2 = session2.registerAgent(agentId = "creator2") ?: throw AssertionError("could not register agent") + val participant2 = session2.registerAgent(agentId = "participant2") ?: throw AssertionError("could not register agent") + + // Create a thread in the first session + val thread1 = session1.createThread( + name = "Thread in Session 1", + creatorId = creator1.id, + participantIds = listOf("participant1") + ) + + // Create a thread in the second session + val thread2 = session2.createThread( + name = "Thread in Session 2", + creatorId = creator2.id, + participantIds = listOf("participant2") + ) + + // Verify threads were created + assertNotNull(thread1) + assertNotNull(thread2) + + // Verify thread1 is accessible in session1 + val retrievedThread1 = session1.getThread(thread1.id) + assertNotNull(retrievedThread1) + assertEquals("Thread in Session 1", retrievedThread1?.name) + + // Verify thread2 is accessible in session2 + val retrievedThread2 = session2.getThread(thread2.id) + assertNotNull(retrievedThread2) + assertEquals("Thread in Session 2", retrievedThread2?.name) + + // Verify thread1 is NOT accessible in session2 + val thread1InSession2 = session2.getThread(thread1.id) + assertNull(thread1InSession2) + + // Verify thread2 is NOT accessible in session1 + val thread2InSession1 = session1.getThread(thread2.id) + assertNull(thread2InSession1) + } + + @Test + fun `test agents are not available across sessions with different privacy keys`() { + // Create two sessions with different privacy keys + val session1 = sessionManager.createSessionWithId("session1", "app1", "key1") + val session2 = sessionManager.createSessionWithId("session2", "app1", "key2") + + // Register agents in both sessions with the same IDs + val agent1 = session1.registerAgent(agentId = "agent1") ?: throw AssertionError("could not register agent") + val agent2 = session1.registerAgent(agentId = "agent2") ?: throw AssertionError("could not register agent") + + val agent1InSession2 = session2.registerAgent(agentId = "agent1") ?: throw AssertionError("could not register agent") + val agent2InSession2 = session2.registerAgent(agentId = "agent2") ?: throw AssertionError("could not register agent") + + // Verify agents were registered in their respective sessions + val retrievedAgent1InSession1 = session1.getAgent(agent1.id) + assertNotNull(retrievedAgent1InSession1) + + val retrievedAgent1InSession2 = session2.getAgent(agent1InSession2.id) + assertNotNull(retrievedAgent1InSession2) + + // Verify agents in different sessions with the same ID are different objects + assertNotEquals(retrievedAgent1InSession1, retrievedAgent1InSession2) + + // Verify the count of agents in each session + assertEquals(2, session1.getAllAgents().size) + assertEquals(2, session2.getAllAgents().size) + + // Verify that creating a thread with agents from another session fails + val thread1 = session1.createThread( + name = "Thread in Session 1", + creatorId = "agent1", + participantIds = listOf("agent2") + ) + + assertNotNull(thread1) + + // Try to create a thread in session2 with agent1 from session1 as a participant + // This should still create the thread but only with valid participants from session2 + val thread2 = session2.createThread( + name = "Thread in Session 2", + creatorId = "agent1", + participantIds = listOf("nonexistent") + ) + + assertNotNull(thread2) + assertEquals(1, thread2.participants.size) // Only the creator should be included + } +} diff --git a/src/test/kotlin/org/coralprotocol/coralserver/session/SessionTest.kt b/src/test/kotlin/org/coralprotocol/coralserver/session/SessionTest.kt new file mode 100644 index 0000000000000000000000000000000000000000..85ec2a731add04c24e49bde85d32230ff8251aca --- /dev/null +++ b/src/test/kotlin/org/coralprotocol/coralserver/session/SessionTest.kt @@ -0,0 +1,390 @@ +package org.coralprotocol.coralserver.session + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import org.coralprotocol.coralserver.models.Agent +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import java.util.concurrent.atomic.AtomicBoolean + +class SessionTest { + private lateinit var session: CoralAgentGraphSession + + @BeforeEach + fun setup() { + // Create a new session for each test + session = CoralAgentGraphSession("test-session", "test-app", "test-key", agentGraph = null) + // Clear any existing data + session.clearAll() + } + + @Test + fun `test agent registration`() { + // Register a new agent + val agent = session.registerAgent("agent1") + + // Verify agent was registered + assertNotNull(agent) + assertEquals(agent, session.getAgent(agent!!.id)) + + // Try to register the same agent again + val duplicateAgent = session.registerAgent("agent1") + assertNull(duplicateAgent) + } + + @Test + fun `test agent registration with description`() { + // Register a new agent with description + val agent = session.registerAgent("agent2", "This agent is responsible for testing") + + // Verify agent was registered with description + assertNotNull(agent) + val retrievedAgent = session.getAgent(agent!!.id) + assertEquals(agent, retrievedAgent) + assertEquals("This agent is responsible for testing", retrievedAgent?.description) + } + + @Test + fun `test thread creation`() { + // Register agents + val creator = session.registerAgent("creator")!! + val participant1 = session.registerAgent("participant1")!! + val participant2 = session.registerAgent("participant2")!! + + // Create a thread + val thread = session.createThread( + name = "Test Thread", + creatorId = creator.id, + participantIds = listOf(participant1.id, participant2.id), + ) + + // Verify thread was created + assertNotNull(thread) + assertEquals("Test Thread", thread.name) + assertEquals("creator", thread.creatorId) + assertTrue(thread.participants.contains("creator")) + assertTrue(thread.participants.contains("participant1")) + assertTrue(thread.participants.contains("participant2")) + assertEquals(3, thread.participants.size) + } + + @Test + fun `test adding and removing participants`() { + // Register agents + val creator = session.registerAgent("creator") + val participant1 = session.registerAgent("participant1") + val participant2 = session.registerAgent("participant2") + val participant3 = session.registerAgent("participant3") + + // Create a thread + val thread = session.createThread( + name = "Test Thread", + creatorId = "creator", + participantIds = listOf("participant1") + ) + + // Add a participant + val addSuccess = session.addParticipantToThread( + threadId = thread.id ?: "", + participantId = "participant2" + ) + + // Verify participant was added + assertTrue(addSuccess) + val updatedThread = session.getThread(thread.id ?: "") + assertTrue(updatedThread?.participants?.contains("participant2") ?: false) + + // Remove a participant + val removeSuccess = session.removeParticipantFromThread( + threadId = thread.id ?: "", + participantId = "participant1" + ) + + // Verify participant was removed + assertTrue(removeSuccess) + val finalThread = session.getThread(thread.id ?: "") + assertFalse(finalThread?.participants?.contains("participant1") ?: true) + } + + @Test + fun `test sending messages and closing thread`() { + // Register agents + val creator = session.registerAgent("creator") + val participant = session.registerAgent("participant") + + // Create a thread + val thread = session.createThread( + name = "Test Thread", + creatorId = "creator", + participantIds = listOf("participant") + ) + + // Send a message + val message = session.sendMessage( + threadId = thread.id ?: "", + senderId = "creator", + content = "Hello, world!", + mentions = listOf("participant") + ) + + // Verify message was sent + assertNotNull(message) + assertEquals("Hello, world!", message.content) + assertEquals("creator", message.sender.id) + assertEquals(thread.id, message.thread.id) + assertTrue(message.mentions.contains("participant") ?: false) + + // Close the thread + val closeSuccess = session.closeThread( + threadId = thread.id ?: "", + summary = "Thread completed" + ) + + // Verify thread was closed + assertTrue(closeSuccess) + val closedThread = session.getThread(thread.id ?: "") + assertTrue(closedThread?.isClosed ?: false) + assertEquals("Thread completed", closedThread?.summary) + + // Try to send a message to a closed thread + assertThrows { + val failedMessage = session.sendMessage( + threadId = thread.id ?: "", + senderId = "creator", + content = "This should fail", + mentions = listOf() + ) + } + } + + @Test + fun `test waiting for mentions`() = runBlocking { + // Register agents + val creator = session.registerAgent("creator") + val participant = session.registerAgent("participant") + + // Create a thread + val thread = session.createThread( + name = "Test Thread", + creatorId = "creator", + participantIds = listOf("participant") + ) + + // Launch a coroutine to wait for mentions + val waitJob = launch(Dispatchers.Default) { + val messages = session.waitForMentions( + agentId = "participant", + timeoutMs = 5000 + ) + + // Verify messages were received + assertFalse(messages.isEmpty()) + assertEquals(1, messages.size) + assertEquals("Hello, participant!", messages[0].content) + } + + // Wait a bit to ensure the wait operation has started + delay(100) + + // Send a message with a mention + session.sendMessage( + threadId = thread?.id ?: "", + senderId = "creator", + content = "Hello, participant!", + mentions = listOf("participant") + ) + + // Wait for the job to complete + waitJob.join() + } + + @Test + fun `test waiting for mentions with timeout`() = runBlocking { + // Register an agent + val agent = session.registerAgent("agent") + + // Wait for mentions with a short timeout + val messages = session.waitForMentions( + agentId = "agent", + timeoutMs = 100 + ) + + // Verify no messages were received + assertTrue(messages.isEmpty()) + } + + @Test + fun `test listing all agents`() { + // Register multiple agents + val agent1 = session.registerAgent("agent1") + val agent2 = session.registerAgent("agent2") + val agent3 = session.registerAgent("agent3") + + // Get all agents + val agents = session.getAllAgents() + + // Verify all agents are returned + assertEquals(3, agents.size) + assertTrue(agents.contains(agent1)) + assertTrue(agents.contains(agent2)) + assertTrue(agents.contains(agent3)) + } + + @Test + fun `test waiting for agent count`() = runBlocking { + // Register some agents + val agent1 = session.registerAgent("agent1") + val agent2 = session.registerAgent("agent2") + + // Verify current count + assertEquals(2, session.getRegisteredAgentsCount()) + + // Launch a coroutine to wait for more agents + val waitJob = launch(Dispatchers.Default) { + val result = session.waitForAgentCount( + targetCount = 3, + timeoutMs = 5000 + ) + + // Verify wait was successful + assertTrue(result) + assertEquals(3, session.getRegisteredAgentsCount()) + } + + // Wait a bit to ensure the wait operation has started + delay(100) + + // Register another agent + val agent3 = session.registerAgent("agent3") + + // Wait for the job to complete + waitJob.join() + } + + @Test + fun `test waiting for agent count with timeout`() = runBlocking { + // Register some agents + val agent1 = session.registerAgent("agent1") + + // Wait for more agents with a short timeout + val result = session.waitForAgentCount( + targetCount = 3, + timeoutMs = 100 + ) + + // Verify wait timed out + assertFalse(result) + assertEquals(1, session.getRegisteredAgentsCount()) + } + + @Test + fun `test get threads for agent`() { + // Register agents + val creator = session.registerAgent("creator") + val participant1 = session.registerAgent("participant1") + val participant2 = session.registerAgent("participant2") + + // Create threads + val thread1 = session.createThread( + name = "Thread 1", + creatorId = "creator", + participantIds = listOf("participant1") + ) + + val thread2 = session.createThread( + name = "Thread 2", + creatorId = "creator", + participantIds = listOf("participant1", "participant2") + ) + + val thread3 = session.createThread( + name = "Thread 3", + creatorId = "participant2", + participantIds = listOf("creator") + ) + + // Get threads for participant1 + val threadsForParticipant1 = session.getThreadsForAgent("participant1") + + // Verify correct threads are returned + assertEquals(2, threadsForParticipant1.size) + assertTrue(threadsForParticipant1.contains(thread1)) + assertTrue(threadsForParticipant1.contains(thread2)) + assertFalse(threadsForParticipant1.contains(thread3)) + + // Get threads for participant2 + val threadsForParticipant2 = session.getThreadsForAgent("participant2") + + // Verify correct threads are returned + assertEquals(2, threadsForParticipant2.size) + assertTrue(threadsForParticipant2.contains(thread2)) + assertTrue(threadsForParticipant2.contains(thread3)) + assertFalse(threadsForParticipant2.contains(thread1)) + } + + @Test + fun `test multiple connections from same client with waitForAgents`() = runBlocking { + // Set the required agent count + session.devRequiredAgentStartCount = 3 + + // Create flags to track when each agent is registered + val agent1Registered = AtomicBoolean(false) + val agent2Registered = AtomicBoolean(false) + val agent3Registered = AtomicBoolean(false) + + // Launch 3 coroutines to simulate 3 concurrent connections + val connectionJobs = List(3) { index -> + launch(Dispatchers.IO) { + // Simulate a delay between connections + delay(100L * index) + + // Create an agent for this connection + val agentId = "agent-${index + 1}" + val agent = session.registerAgent(agentId) + + // Set the flag for this agent + when (index) { + 0 -> agent1Registered.set(true) + 1 -> agent2Registered.set(true) + 2 -> agent3Registered.set(true) + } + + // If this is the first or second agent, wait for all agents to be registered + if (index < 2) { + println("[DEBUG_LOG] Agent $agentId waiting for all agents to be registered") + val result = session.waitForAgentCount( + targetCount = 3, + timeoutMs = 5000 + ) + println("[DEBUG_LOG] Agent $agentId wait result: $result") + + // Verify wait was successful + assertTrue(result, "Agent $agentId wait should succeed") + assertEquals(3, session.getRegisteredAgentsCount(), "All 3 agents should be registered") + + // Verify all agents are registered + assertTrue(agent1Registered.get(), "Agent 1 should be registered") + assertTrue(agent2Registered.get(), "Agent 2 should be registered") + assertTrue(agent3Registered.get(), "Agent 3 should be registered") + } + } + } + + // Wait for all connections to complete + connectionJobs.forEach { it.join() } + + // Verify that all 3 agents are registered + assertEquals(3, session.getRegisteredAgentsCount(), "All 3 agents should be registered") + + // Verify that all 3 agents are in the session + val agents = session.getAllAgents() + assertEquals(3, agents.size, "Session should have 3 registered agents") + assertTrue(agents.any { it.id == "agent-1" }, "Agent 1 should be registered") + assertTrue(agents.any { it.id == "agent-2" }, "Agent 2 should be registered") + assertTrue(agents.any { it.id == "agent-3" }, "Agent 3 should be registered") + } +}