File size: 4,087 Bytes
e0552b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from koja_diffuser.runtime.inference import Inference
import asyncio
from contextlib import suppress
from typing import Any, Awaitable, AsyncGenerator, Literal
from pydantic import BaseModel, Field
import json


# Module-level ASGI application served by this file.
app = FastAPI()
# One Inference instance, created at import time and pinned to the CPU, is
# shared by every request handled below.
# NOTE(review): confirm Inference is safe to use from concurrent requests.
inference = Inference(device="cpu")

# Let the browser front-end at http://localhost:5173 (presumably a local dev
# server — confirm) call this API cross-origin, with credentials, using any
# method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def make_sse(data: Any, event: str | None = None) -> str:
    payload = json.dumps(data, ensure_ascii=False)

    lines = []

    if event is not None:
        lines.append(f"event: {event}")

    for line in payload.splitlines():
        lines.append(f"data: {line}")

    return "\n".join(lines) + "\n\n"


class Emitter:
    """Bridge a producer coroutine to a Server-Sent Events byte stream.

    The producer reports progress by awaiting :meth:`emit`. Each call is
    serialised into an SSE frame and queued. :meth:`stream` runs the producer
    as a background task and yields the queued frames in order until the
    producer finishes, the producer fails (reported as an ``error`` event), or
    the client disconnects (the producer is then cancelled).
    """

    def __init__(self) -> None:
        # Frames waiting to be sent. ``None`` is the end-of-stream sentinel
        # that ``_run`` enqueues once the producer is done.
        self.queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def emit(self, event: str, data: Any) -> None:
        """Queue *data* as a JSON-encoded SSE frame named *event*."""
        await self.queue.put(make_sse(data, event=event))

    async def error(self, data: object) -> None:
        """Queue an ``error`` event whose payload is ``{"message": str(data)}``.

        *data* may be a plain message or an exception instance; ``_run``
        passes the exception raised by the producer.
        """
        await self.emit("error", {"message": str(data)})

    async def _run(self, awaitable: Awaitable[Any]) -> None:
        """Await the producer and always terminate the stream.

        Any ``Exception`` from the producer becomes an ``error`` event.
        Cancellation is propagated unchanged. In every case the ``None``
        sentinel is enqueued so the consumer loop can finish.
        """
        try:
            await awaitable
        except asyncio.CancelledError:
            # Cancellation (e.g. client disconnect) is not reported as an error.
            raise
        except Exception as exc:
            await self.error(exc)
        finally:
            await self.queue.put(None)

    async def stream(
        self,
        awaitable: Awaitable[Any],
        *,
        request: Any = None,
        poll_interval: float = 0.5,
    ) -> AsyncGenerator[str, None]:
        """Run *awaitable* in the background and yield its SSE frames.

        Args:
            awaitable: Producer that reports progress via :meth:`emit`.
            request: Optional object with an ``async is_disconnected()`` method
                (used with the incoming HTTP request). When it reports a
                disconnect, streaming stops and the producer is cancelled.
            poll_interval: Maximum seconds to wait for the next frame before
                re-checking *request* for a disconnect. It bounds how long a
                disconnect can go unnoticed while the producer is quiet. Must
                be positive.

        Yields:
            SSE frame strings, in the order they were emitted.
        """
        task = asyncio.create_task(self._run(awaitable))

        try:
            while True:
                if request is not None and await request.is_disconnected():
                    break  # the finally clause cancels the producer

                try:
                    item = await asyncio.wait_for(
                        self.queue.get(),
                        timeout=poll_interval,
                    )
                except asyncio.TimeoutError:
                    continue  # nothing yet; loop back to the disconnect check

                if item is None:  # end-of-stream sentinel from _run
                    break

                yield item

        finally:
            # Covers normal completion (no-op on a finished task), disconnect,
            # and the consumer closing the generator early.
            task.cancel()

            with suppress(asyncio.CancelledError):
                await task


async def generate(count: int = 5, delay: float = 1.0):
    """Yield *count* demo SSE frames, pausing *delay* seconds after each.

    Each frame is ``"data: chunk {i}\\n\\n"`` for ``i`` in ``range(count)``.
    The defaults reproduce the original fixed behaviour: five chunks, one
    second apart. Nothing visible in this module calls it; it looks like a
    leftover streaming demo — confirm before relying on it.

    Args:
        count: Number of frames to yield (``0`` yields nothing).
        delay: Seconds to sleep after each frame.
    """
    for i in range(count):
        yield f"data: chunk {i}\n\n"
        await asyncio.sleep(delay)


# JSON body accepted by POST /stream. Every field is forwarded to the
# inference call by the handler. Fields marked strict=True disable pydantic's
# type coercion (e.g. the string "3" is rejected for an int field).
class NameRequest(BaseModel):
    # Name to convert; must be non-empty. No maximum length is enforced here.
    name: str = Field(min_length=1)
    # Strict integer in [0, 9].
    # NOTE(review): the 0-9 range suggests a bucket/class index rather than an
    # age in years — confirm against the inference model.
    age: int = Field(ge=0, le=9, strict=True)
    # Optional strict integer seed; None (the default) leaves seeding to the
    # inference backend (presumably non-deterministic — confirm).
    seed: int | None = Field(default=None, strict=True)
    # Decoding strategy: "greedy" or "sample" (default).
    sampling_mode: Literal["greedy", "sample"] = Field(default="sample")
    # Sampling temperature in [0.1, 2.0], default 0.8.
    temperature: float = Field(default=0.8, ge=0.1, le=2.0, strict=True)
    # Top-k cutoff in [1, 100], default 20.
    top_k: int = Field(default=20, ge=1, le=100, strict=True)
    # Nucleus (top-p) cutoff in [0.1, 1.0], default 0.9.
    top_p: float = Field(default=0.9, ge=0.1, le=1.0, strict=True)


def detect_script(text: str) -> str:
    """Classify which writing system the letters of *text* belong to.

    Only characters for which ``str.isalpha()`` is true are considered, so
    whitespace, digits and punctuation are ignored.

    Returns:
        ``"unknown"`` if *text* contains no letters.
        ``"hangul"`` if every letter is in U+AC00-U+D7A3, U+1100-U+11FF or
        U+3130-U+318F.
        ``"hiragana"`` if every letter is in U+3040-U+309F.
        ``"mixed_or_other"`` in all remaining cases.
    """
    hangul_ranges = (
        ("\uac00", "\ud7a3"),  # Hangul Syllables
        ("\u1100", "\u11ff"),  # Hangul Jamo
        ("\u3130", "\u318f"),  # Hangul Compatibility Jamo
    )

    def is_hangul(ch: str) -> bool:
        return any(lo <= ch <= hi for lo, hi in hangul_ranges)

    def is_hiragana(ch: str) -> bool:
        return "\u3040" <= ch <= "\u309f"

    letters = [ch for ch in text if ch.isalpha() and not ch.isspace()]

    if not letters:
        return "unknown"
    if all(map(is_hangul, letters)):
        return "hangul"
    if all(map(is_hiragana, letters)):
        return "hiragana"
    return "mixed_or_other"


@app.post("/stream")
async def stream(request: Request, body: NameRequest):
    """Stream the conversion of ``body.name`` as Server-Sent Events.

    Names whose letters are all hiragana go through ``inference.ja_to_ko``.
    Everything else goes through ``inference.ko_to_ja``: Hangul, mixed
    scripts, and names with no letters at all. The inference call reports
    events through the ``emit`` callback it is given. An exception it raises
    is sent to the client as an ``error`` event, and the stream then ends.
    """
    emitter = Emitter()

    # Named ``translate`` rather than ``generate`` so it does not shadow the
    # module-level ``generate`` function.
    translate = (
        inference.ja_to_ko
        if detect_script(body.name) == "hiragana"
        else inference.ko_to_ja
    )

    # The coroutine is created here, but it only starts running once the
    # response body is iterated (Emitter.stream wraps it in a task).
    job = translate(
        [body.name],
        [body.age],
        seed=body.seed,
        sampling_mode=body.sampling_mode,
        temperature=body.temperature,
        top_k=body.top_k,
        top_p=body.top_p,
        emit=emitter.emit,
    )

    return StreamingResponse(
        emitter.stream(job, request=request),
        media_type="text/event-stream",
        # Event streams must not be cached by browsers or intermediaries.
        headers={"Cache-Control": "no-cache"},
    )