Spaces:
Sleeping
Sleeping
File size: 4,087 Bytes
e0552b0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 | from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from koja_diffuser.runtime.inference import Inference
import asyncio
from contextlib import suppress
from typing import Any, Awaitable, AsyncGenerator, Literal
from pydantic import BaseModel, Field
import json
# ASGI application object for this module; routes are registered on it below.
app = FastAPI()
# One Inference instance, created at import time and shared by every request
# this process handles; it is constructed with device="cpu".
# NOTE(review): presumably this loads model weights eagerly at import — confirm
# in koja_diffuser.runtime.inference.
inference = Inference(device="cpu")
# CORS: browsers may call this API cross-origin only from http://localhost:5173
# (5173 is Vite's default dev-server port — presumably the local frontend).
# Deployed frontend origins must be added to allow_origins. Credentials
# (cookies / auth headers) are permitted for the listed origins, and any
# method or request header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def make_sse(data: Any, event: str | None = None) -> str:
payload = json.dumps(data, ensure_ascii=False)
lines = []
if event is not None:
lines.append(f"event: {event}")
for line in payload.splitlines():
lines.append(f"data: {line}")
return "\n".join(lines) + "\n\n"
class Emitter:
    """Relay events from a background coroutine to a Server-Sent Events stream.

    A producer coroutine is given :meth:`emit` as a callback. Each call
    enqueues one pre-formatted SSE frame. :meth:`stream` runs the producer as
    a task and yields those frames in order until the producer finishes,
    fails (reported as an ``error`` event) or the client disconnects.
    """

    def __init__(self) -> None:
        # Pending SSE frames; ``None`` is the end-of-stream sentinel that
        # ``_run`` always enqueues last.
        self.queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def emit(self, event: str, data: Any) -> None:
        """Enqueue one SSE frame named *event* carrying *data* as JSON."""
        await self.queue.put(make_sse(data, event=event))

    async def error(self, data: object) -> None:
        """Enqueue an ``error`` event with payload ``{"message": str(data)}``.

        *data* is usually the exception raised by the producer, which is why
        it is typed ``object`` rather than ``str``.
        """
        await self.emit(
            "error",
            {
                "message": str(data),
            },
        )

    async def _run(self, awaitable: Awaitable[Any]) -> None:
        """Await the producer, report a failure as an error event, then end the stream."""
        try:
            await awaitable
        except asyncio.CancelledError:
            # Cancellation is control flow, not a failure: never report it.
            raise
        except Exception as exc:
            await self.error(exc)
        finally:
            # Always wake the consumer loop in ``stream`` so it can stop.
            await self.queue.put(None)

    async def stream(
        self,
        awaitable: Awaitable[Any],
        *,
        request=None,
    ) -> AsyncGenerator[str, None]:
        """Run *awaitable* in the background and yield the SSE frames it emits.

        Args:
            awaitable: The producer; it reports progress through :meth:`emit`.
            request: Optional object with an async ``is_disconnected()`` method
                (e.g. a Starlette ``Request``). It is polled roughly every
                0.5 s, and the producer is cancelled once it reports a
                disconnect.

        Yields:
            Pre-formatted SSE frames, in emission order.
        """
        task = asyncio.create_task(self._run(awaitable))
        try:
            while True:
                if request is not None and await request.is_disconnected():
                    break  # the ``finally`` below cancels the producer
                try:
                    item = await asyncio.wait_for(
                        self.queue.get(),
                        timeout=0.5,
                    )
                except asyncio.TimeoutError:
                    # Nothing emitted yet; loop round to re-check the client.
                    continue
                if item is None:
                    break
                yield item
        finally:
            task.cancel()
            # Let the producer finish winding down. ``asyncio.wait`` never
            # raises the producer's own CancelledError, so unlike
            # ``suppress(CancelledError)`` around ``await task`` it cannot
            # swallow a cancellation aimed at *this* coroutine.
            await asyncio.wait({task})
            if not task.cancelled():
                # Surface unexpected failures from ``_run`` itself.
                task.result()
async def generate():
    """Yield five placeholder SSE frames, ``chunk 0`` through ``chunk 4``.

    Each frame is followed by a one-second pause, the last one included.
    NOTE(review): looks like a leftover demo stream — confirm whether anything
    still uses it.
    """
    chunk = 0
    while chunk < 5:
        yield f"data: chunk {chunk}\n\n"
        await asyncio.sleep(1)
        chunk += 1
# Request body for POST /stream, validated by pydantic. Fields marked
# strict=True use pydantic strict mode, so lax coercion (e.g. a numeric
# string for an int field) is rejected.
class NameRequest(BaseModel):
    # Name to convert; must be at least one character. Its script (per
    # detect_script) selects the conversion direction in the /stream endpoint.
    name: str = Field(min_length=1)
    # Integer in [0, 9]; forwarded to inference as a one-element list.
    # NOTE(review): what this range encodes (age bucket? decade?) is not
    # visible here — confirm against Inference.
    age: int = Field(ge=0, le=9, strict=True)
    # Optional RNG seed forwarded to inference; defaults to None.
    # NOTE(review): presumably None means unseeded — confirm in Inference.
    seed: int | None = Field(default=None, strict=True)
    # Decoding strategy forwarded to inference: "greedy" or "sample" (default).
    sampling_mode: Literal["greedy", "sample"] = Field(default="sample")
    # Sampling parameters forwarded unchanged to inference; the bounds below
    # are enforced at request-validation time.
    temperature: float = Field(default=0.8, ge=0.1, le=2.0, strict=True)
    top_k: int = Field(default=20, ge=1, le=100, strict=True)
    top_p: float = Field(default=0.9, ge=0.1, le=1.0, strict=True)
def detect_script(text: str) -> str:
    """Return a coarse script label for *text*.

    Only alphabetic characters are inspected; whitespace, digits,
    punctuation and combining marks are ignored.

    Returns:
        ``"unknown"`` if there are no letters at all;
        ``"hangul"`` if every letter is a Hangul syllable (U+AC00–U+D7A3),
        Hangul Jamo (U+1100–U+11FF) or compatibility Jamo (U+3130–U+318F);
        ``"hiragana"`` if every letter is in the Hiragana block (U+3040–U+309F);
        ``"mixed_or_other"`` otherwise (katakana, Latin, kanji, mixtures, ...).
    """
    letters = [c for c in text if c.isalpha() and not c.isspace()]
    if not letters:
        return "unknown"

    def is_hangul(c: str) -> bool:
        code = ord(c)
        return (
            0xAC00 <= code <= 0xD7A3
            or 0x1100 <= code <= 0x11FF
            or 0x3130 <= code <= 0x318F
        )

    if all(map(is_hangul, letters)):
        return "hangul"
    if all(0x3040 <= ord(c) <= 0x309F for c in letters):
        return "hiragana"
    return "mixed_or_other"
@app.post("/stream")
async def stream(request: Request, body: NameRequest):
    """Run a name conversion for *body* and stream its progress as SSE.

    Direction: names whose letters are all hiragana (per ``detect_script``)
    go to ``inference.ja_to_ko``; every other input goes to
    ``inference.ko_to_ja``. The chosen method reports progress through
    ``emitter.emit``, and an exception it raises arrives as an ``error``
    event. The conversion is cancelled if the client disconnects.
    NOTE(review): katakana, Latin, mixed or letter-less ("unknown") names
    are also routed to ko_to_ja — confirm that is intended, or reject them
    with a 422.
    """
    emitter = Emitter()
    # Named ``translate`` (not ``generate``) so it does not shadow the
    # module-level ``generate`` function.
    translate = (
        inference.ja_to_ko
        if detect_script(body.name) == "hiragana"
        else inference.ko_to_ja
    )
    job = translate(
        [body.name],
        [body.age],
        seed=body.seed,
        sampling_mode=body.sampling_mode,
        temperature=body.temperature,
        top_k=body.top_k,
        top_p=body.top_p,
        emit=emitter.emit,
    )
    return StreamingResponse(
        emitter.stream(job, request=request),
        media_type="text/event-stream",
        # An event stream must not be cached, and reverse proxies (e.g. nginx)
        # must not buffer it, or frames reach the browser late or all at once.
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
|