#include "socket_registry.h"

#include <stdio.h>

#include <algorithm>
#include <map>
#include <vector>

#include "threads.h"

// Registry of the sockets that each proxy connection has opened. It is used to
// verify that a connection only operates on sockets it owns, and to close the
// sockets a connection leaves behind. All access to the registry is serialized
// by socketRegistryLock.

namespace
{
  typedef std::vector<SOCKET_T> SocketList;
  typedef std::map<int, SocketList> SocketRegistry;

  MUTEX_T socketRegistryLock;
  SocketRegistry socketsPerProxyConnection;

  // Holds socketRegistryLock for the lifetime of the object, so that every
  // return path of a function releases the lock.
  class RegistryLockGuard
  {
  public:
    RegistryLockGuard() { LOCK_MUTEX(&socketRegistryLock); }
    ~RegistryLockGuard() { UNLOCK_MUTEX(&socketRegistryLock); }

  private:
    // Non-copyable: copying would unlock the mutex twice.
    RegistryLockGuard(const RegistryLockGuard &);
    RegistryLockGuard &operator=(const RegistryLockGuard &);
  };
}

// Initializes the registry lock. Must be called before any other function in
// this file is used.
void InitWebSocketRegistry()
{
  CREATE_MUTEX(&socketRegistryLock);
}

// Records that usedSocket belongs to proxyConnection. Socket 0 is the
// "no socket" placeholder and is never recorded; a socket that is already
// recorded for the connection is not added a second time.
void TrackSocketUsedByConnection(int proxyConnection, SOCKET_T usedSocket)
{
  if (usedSocket == 0)
    return;

  // The membership test and the insertion happen under a single lock
  // acquisition, so two concurrent callers cannot both register the socket.
  RegistryLockGuard guard;
  SocketList &sockets = socketsPerProxyConnection[proxyConnection];
  if (std::find(sockets.begin(), sockets.end(), usedSocket) == sockets.end())
    sockets.push_back(usedSocket);
}

// Closes usedSocket and removes it from proxyConnection's list. Does nothing if
// the socket is not recorded for that connection.
void CloseSocketByConnection(int proxyConnection, SOCKET_T usedSocket)
{
  // Socket 0 is a placeholder that is never tracked; closing it would close a
  // descriptor this registry does not own.
  if (usedSocket == 0)
    return;

  RegistryLockGuard guard;

  // Look the connection up without operator[], which would create an empty
  // entry for a connection that is unknown or already cleaned up.
  SocketRegistry::iterator entry = socketsPerProxyConnection.find(proxyConnection);
  if (entry == socketsPerProxyConnection.end())
    return;

  SocketList &sockets = entry->second;
  SocketList::iterator firstRemoved = std::remove(sockets.begin(), sockets.end(), usedSocket);
  if (firstRemoved == sockets.end())
    return; // Not owned by this connection (or closed by a concurrent caller).

  printf("Closing socket fd %d used by proxy connection %d\n", (int)usedSocket, proxyConnection);
  CLOSE_SOCKET(usedSocket);
  sockets.erase(firstRemoved, sockets.end());
}

// Shuts down and closes every socket recorded for proxyConnection, then drops
// the connection's entry from the registry.
void CloseAllSocketsByConnection(int proxyConnection)
{
  RegistryLockGuard guard;

  SocketRegistry::iterator entry = socketsPerProxyConnection.find(proxyConnection);
  if (entry == socketsPerProxyConnection.end())
    return; // Nothing was ever tracked for this connection.

  const SocketList &sockets = entry->second;
  for (size_t i = 0; i < sockets.size(); ++i)
  {
    printf("Closing socket fd %d used by proxy connection %d.\n", (int)sockets[i], proxyConnection);
    shutdown(sockets[i], SHUTDOWN_BIDIRECTIONAL);
    CLOSE_SOCKET(sockets[i]);
  }
  socketsPerProxyConnection.erase(entry);
}

// Returns true if usedSocket is recorded for proxyConnection. Socket 0 is
// accepted for every connection.
bool IsSocketPartOfConnection(int proxyConnection, SOCKET_T usedSocket)
{
  // Allow all proxy connections to access "socket 0" when/if they need to refer
  // to a socket that does not exist.
  if (usedSocket == 0)
    return true;

  RegistryLockGuard guard;
  SocketRegistry::iterator entry = socketsPerProxyConnection.find(proxyConnection);
  if (entry == socketsPerProxyConnection.end())
    return false;

  const SocketList &sockets = entry->second;
  return std::find(sockets.begin(), sockets.end(), usedSocket) != sockets.end();
}