# GradLLM / app.py — Hugging Face Space source (commit 514fb81)
# app.py
import asyncio
from contextlib import asynccontextmanager
import gradio as gr
from fastapi import FastAPI
from config import settings
from rabbit_base import RabbitBase
from listener import RabbitListenerBase
from rabbit_repo import RabbitRepo
from service import LLMService
from runners.base import ILLMRunner
# =========================
# @spaces.GPU() SECTION
# =========================
# This trivial GPU endpoint keeps ZeroGPU Spaces alive at startup.
try:
    import spaces

    ZERO_GPU_AVAILABLE = True

    @spaces.GPU()  # keep it trivial (no tensor allocations)
    def gpu_ready_probe() -> str:
        """
        Minimal GPU-decorated function so ZeroGPU detects a GPU entrypoint.
        It's also referenced by a Gradio button and a FastAPI route below.
        """
        return "gpu-probe-ok"

except Exception:
    # Anything going wrong here (missing package, decorator failure) means
    # we are on a local/CPU-only box; install a same-signature fallback.
    ZERO_GPU_AVAILABLE = False

    def gpu_ready_probe() -> str:
        """Fallback probe used when the `spaces` runtime is unavailable."""
        return "cpu-only"
# ---------------- Runner factory (stub) ----------------
class EchoRunner(ILLMRunner):
    """Placeholder runner: accepts every lifecycle call and does nothing.

    Swap this out for a real ILLMRunner implementation to talk to a model.
    """

    Type = "EchoRunner"

    async def StartProcess(self, llmServiceObj: dict):  # noqa: N802
        """Accept a start-session request; no process is spawned."""

    async def RemoveProcess(self, sessionId: str):  # noqa: N802
        """Accept a remove-session request; nothing to tear down."""

    async def StopRequest(self, sessionId: str):  # noqa: N802
        """Accept a stop request; nothing is running."""

    async def SendInputAndGetResponse(self, llmServiceObj: dict):  # noqa: N802
        """Accept user input; produces no response (returns None)."""
async def runner_factory(llmServiceObj: dict) -> ILLMRunner:
    """Build the runner for a new session; currently always an EchoRunner stub."""
    return EchoRunner()
# ---------------- Publisher and Service ----------------
# Outbound RabbitMQ publisher; external_source labels messages with this origin.
publisher = RabbitRepo(external_source="https://space.external")
# Core service: routes session lifecycle/input events to runners from runner_factory.
service = LLMService(publisher, runner_factory)
# ---------------- Handlers (.NET FuncName -> service) ----------------
# Each handler forwards a decoded message payload to the service layer,
# substituting an empty dict when the payload is falsy/missing.
async def h_start(data):
    """Forward a session-start payload to the service."""
    await service.StartProcess(data or {})


async def h_user(data):
    """Forward a user-input payload to the service."""
    await service.UserInput(data or {})


async def h_remove(data):
    """Forward a remove-session payload to the service."""
    await service.RemoveSession(data or {})


async def h_stop(data):
    """Forward a stop-request payload to the service."""
    await service.StopRequest(data or {})


async def h_qir(data):
    """Forward a query-index-result payload to the service."""
    await service.QueryIndexResult(data or {})


async def h_getreg(_):
    """Request the unfiltered function registry; payload is ignored."""
    await service.GetFunctionRegistry(False)


async def h_getreg_f(_):
    """Request the filtered function registry; payload is ignored."""
    await service.GetFunctionRegistry(True)
# Dispatch table: .NET-style FuncName -> coroutine handler.
# (All keys are valid identifiers, so the dict() keyword form is equivalent.)
handlers = dict(
    llmStartSession=h_start,
    llmUserInput=h_user,
    llmRemoveSession=h_remove,
    llmStopRequest=h_stop,
    queryIndexResult=h_qir,
    getFunctionRegistry=h_getreg,
    getFunctionRegistryFiltered=h_getreg_f,
)
# ---------------- Listener wiring ----------------
# Shared RabbitMQ connection state used by the listener.
base = RabbitBase()
listener = RabbitListenerBase(
    base,
    instance_name=settings.RABBIT_INSTANCE_NAME,  # queue prefix like your .NET instance
    handlers=handlers,
)
# Declarations mirror your C# InitRabbitMQObjs().
# Each exchange is named FuncName + SERVICE_ID; only the message timeout
# varies per function, so build the dicts from a (FuncName, timeout-ms) table
# instead of seven hand-copied literals.
_DECL_SPECS = [
    ("llmStartSession", 600_000),
    ("llmUserInput", 600_000),
    ("llmRemoveSession", 60_000),
    ("llmStopRequest", 60_000),
    ("queryIndexResult", 60_000),
    ("getFunctionRegistry", 60_000),
    ("getFunctionRegistryFiltered", 60_000),
]
DECLS = [
    {
        "ExchangeName": f"{func}{settings.SERVICE_ID}",
        "FuncName": func,
        "MessageTimeout": timeout,
        # Fresh list per declaration so one entry can't mutate another's keys.
        "RoutingKeys": [settings.RABBIT_ROUTING_KEY],
    }
    for func, timeout in _DECL_SPECS
]
# ---------------- Gradio UI (smoke test + GPU probe) ----------------
async def ping() -> str:
    """Liveness check backing the Gradio smoke-test button."""
    return "ok"
# Minimal UI: a ping button for smoke testing, plus a GPU-probe button when
# running on ZeroGPU. `demo` is mounted onto the FastAPI app further below.
with gr.Blocks() as demo:
    gr.Markdown("### LLM Runner (Python) — RabbitMQ listener")
    with gr.Row():
        btn = gr.Button("Ping")
        out = gr.Textbox(label="Ping result")
    btn.click(ping, inputs=None, outputs=out)
    # IMPORTANT: reference the decorated function DIRECTLY (no lambda)
    if ZERO_GPU_AVAILABLE:
        probe_btn = gr.Button("GPU Probe")
        probe_out = gr.Textbox(label="GPU Probe Result")
        probe_btn.click(gpu_ready_probe, None, probe_out)
# ---------------- FastAPI + lifespan ----------------
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """App lifespan: bring up messaging on startup, hold until shutdown.

    Presumably order matters here — publisher connected and service
    initialized before the listener begins consuming; verify before reordering.
    """
    # startup
    await publisher.connect()
    await service.init()
    await listener.start(DECLS)
    yield
    # shutdown (optional)
    # await publisher.close()
    # await listener.stop()
# FastAPI app with the messaging lifespan; the Gradio UI is mounted at "/".
app = FastAPI(lifespan=lifespan)
app = gr.mount_gradio_app(app, demo, path="/")
@app.get("/health")
async def health():
    """Liveness endpoint for container / Space health checks."""
    payload = {"status": "ok"}
    return payload
# Also expose probe via HTTP (belt & braces for ZeroGPU detectors)
@app.get("/gpu-probe")
def gpu_probe_route():
    """Report the GPU probe result over HTTP."""
    result = gpu_ready_probe()
    return {"status": result}
# Local entrypoint: run the ASGI app directly with uvicorn on the standard
# Spaces port (7860). Hosted Spaces may launch via their own runner instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)