import argparse
import json
import sys
import threading
import time
import traceback
from datetime import UTC, datetime
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path

import sentencepiece as spm
import torch

# Put the repository's ``src`` directory first on the import path so the
# ``sovyn`` package below resolves to the in-repo sources.
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from sovyn import SovynConfig, SovynForCausalLM
from sovyn.formatting import format_prompt

from chat import clean_answer, score_answer
|
|
|
|
| def now_iso(): |
| return datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z") |
|
|
|
|
| class SovynRuntime: |
| def __init__(self, args): |
| self.model_name = args.model_name |
| self.max_new_tokens = args.max_new_tokens |
| self.temperature = args.temperature |
| self.top_k = args.top_k |
| self.best_of = args.best_of |
| self.checkpoint_path = Path(args.checkpoint) |
|
|
| device = args.device |
| if device == "cuda" and not torch.cuda.is_available(): |
| device = "cpu" |
| self.device = device |
|
|
| self.tokenizer = spm.SentencePieceProcessor(model_file=args.tokenizer) |
| checkpoint = torch.load(self.checkpoint_path, map_location="cpu") |
| model_cfg = checkpoint["config"]["model"] |
| self.model = SovynForCausalLM(SovynConfig(**model_cfg)) |
| self.model.load_state_dict(checkpoint["model"]) |
| dtype = torch.bfloat16 if device == "cuda" else torch.float32 |
| self.model.to(device=device, dtype=dtype) |
| self.model.eval() |
|
|
| self.eos_id = self.tokenizer.piece_to_id("<eos>") |
| self.stop_ids = [ |
| self.tokenizer.piece_to_id(piece) |
| for piece in ["<system>", "<user>", "<state>", "<plan>", "<memory>", "<reflection>"] |
| if self.tokenizer.piece_to_id(piece) >= 0 |
| ] |
| self.suppress_ids = [ |
| idx |
| for idx in [ |
| self.tokenizer.piece_to_id("<pad>"), |
| self.tokenizer.piece_to_id("<unk>"), |
| self.tokenizer.piece_to_id("<bos>"), |
| ] |
| if idx >= 0 |
| ] |
|
|
| @torch.no_grad() |
| def reply(self, user: str, system: str | None = None, options: dict | None = None) -> str: |
| options = options or {} |
| temperature = float(options.get("temperature", self.temperature)) |
| top_k = int(options.get("top_k", self.top_k)) |
| max_new_tokens = int(options.get("num_predict", self.max_new_tokens)) |
| best_of = max(1, int(options.get("best_of", self.best_of))) |
| runs = best_of if temperature > 0 else 1 |
|
|
| prompt = format_prompt(user, system=system) |
| ids = torch.tensor( |
| [self.tokenizer.encode(prompt, out_type=int)], |
| dtype=torch.long, |
| device=self.device, |
| ) |
|
|
| candidates = [] |
| for _ in range(runs): |
| out = self.model.generate( |
| ids, |
| max_new_tokens=max_new_tokens, |
| temperature=temperature, |
| top_k=top_k, |
| eos_id=self.eos_id, |
| stop_ids=self.stop_ids, |
| suppress_ids=self.suppress_ids, |
| ) |
| answer = clean_answer(self.tokenizer.decode(out[0].tolist())) |
| candidates.append(answer) |
| return max(candidates, key=lambda answer: score_answer(user, answer)) |
|
|
| def tags(self): |
| size = self.checkpoint_path.stat().st_size if self.checkpoint_path.exists() else 0 |
| return { |
| "models": [ |
| { |
| "name": self.model_name, |
| "model": self.model_name, |
| "modified_at": now_iso(), |
| "size": size, |
| "digest": "sovyn-local-pytorch", |
| "details": { |
| "parent_model": "", |
| "format": "pytorch", |
| "family": "sovyn", |
| "families": ["sovyn"], |
| "parameter_size": "300M", |
| "quantization_level": "BF16", |
| }, |
| } |
| ] |
| } |
|
|
|
|
def json_bytes(payload: dict) -> bytes:
    """Serialize *payload* to UTF-8 encoded JSON, leaving non-ASCII text unescaped."""
    text = json.dumps(payload, ensure_ascii=False)
    return text.encode("utf-8")
|
|
|
|
| def get_last_user_and_system(messages: list[dict]) -> tuple[str, str | None]: |
| system = None |
| user = "" |
| for message in messages: |
| role = message.get("role") |
| content = message.get("content", "") |
| if role == "system" and content: |
| system = content |
| elif role == "user" and content: |
| user = content |
| return user, system |
|
|
|
|
def make_handler(runtime: SovynRuntime):
    """Build a ``BaseHTTPRequestHandler`` subclass that serves *runtime*.

    Routes (a subset of the Ollama HTTP API):
        GET/HEAD ``/``, ``/api/version``, ``/api/tags``
        POST ``/api/generate``, ``/api/chat``, ``/api/show``

    The class is created inside this factory so handler instances reach
    *runtime* through the closure.
    """
    get_routes = ("/", "/api/version", "/api/tags")

    class Handler(BaseHTTPRequestHandler):
        server_version = "SOVYN-Ollama-Bridge/0.1"

        def log_message(self, fmt, *args):
            """Write log lines to stdout, flushing each one immediately."""
            sys.stdout.write("%s - %s\n" % (self.address_string(), fmt % args))
            sys.stdout.flush()

        def _route(self) -> str:
            """Return the request path with any query string removed."""
            return self.path.split("?", 1)[0]

        def send_json(self, status: int, payload: dict):
            """Send *payload* as a complete JSON response with ``Content-Length``."""
            body = json_bytes(payload)
            self.send_response(status)
            self.send_header("Content-Type", "application/json; charset=utf-8")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def send_stream_json(self, payload: dict):
            """Send *payload* as a single NDJSON line with status 200.

            No ``Content-Length`` is sent. Under BaseHTTPRequestHandler's default
            HTTP/1.0 protocol the connection is closed after the response,
            which terminates the body.
            """
            body = json_bytes(payload) + b"\n"
            self.send_response(200)
            self.send_header("Content-Type", "application/x-ndjson; charset=utf-8")
            self.end_headers()
            self.wfile.write(body)

        def read_payload(self) -> dict:
            """Read and decode the JSON request body.

            Returns:
                The decoded JSON object, or ``{}`` when there is no body.

            Raises:
                ValueError: if Content-Length is not an integer, the body is
                    not UTF-8 JSON, or the JSON value is not an object.
            """
            length = int(self.headers.get("Content-Length", "0"))
            if length <= 0:
                return {}
            raw = self.rfile.read(length).decode("utf-8")
            if not raw:
                return {}
            payload = json.loads(raw)
            if not isinstance(payload, dict):
                raise ValueError("request body must be a JSON object")
            return payload

        def do_GET(self):
            """Serve version and model-listing routes."""
            route = self._route()
            if route in ("/", "/api/version"):
                self.send_json(200, {"version": "sovyn-ollama-bridge-0.1"})
            elif route == "/api/tags":
                self.send_json(200, runtime.tags())
            else:
                self.send_json(404, {"error": f"unknown route: {self.path}"})

        def do_HEAD(self):
            """Answer HEAD probes (e.g. ``HEAD /`` heartbeats) with headers only."""
            self.send_response(200 if self._route() in get_routes else 404)
            self.send_header("Content-Type", "application/json; charset=utf-8")
            self.end_headers()

        def do_POST(self):
            """Dispatch POST routes.

            Returns 400 for an unreadable body, 404 for unknown routes and 500
            (with the traceback logged) for handler failures.
            """
            started = time.perf_counter_ns()
            try:
                payload = self.read_payload()
            except ValueError as exc:
                self.send_json(400, {"error": f"invalid request body: {exc}"})
                return

            route = self._route()
            handlers = {
                "/api/generate": self._handle_generate,
                "/api/chat": self._handle_chat,
                "/api/show": self._handle_show,
            }
            handle = handlers.get(route)
            if handle is None:
                self.send_json(404, {"error": f"unknown route: {self.path}"})
                return

            try:
                handle(payload, started)
            except (BrokenPipeError, ConnectionResetError):
                # The client is gone; there is nobody to send an error to.
                self.log_error("client disconnected during %s", route)
            except Exception as exc:
                self.log_error("%s failed:\n%s", route, traceback.format_exc())
                self.send_json(500, {"error": str(exc)})

        def _send_result(self, payload: dict, response: dict):
            """Send *response* as NDJSON unless the request's ``stream`` is falsy."""
            if payload.get("stream", True):
                self.send_stream_json(response)
            else:
                self.send_json(200, response)

        def _handle_generate(self, payload: dict, started: int):
            """Serve ``POST /api/generate`` (``prompt``, optional ``system``, ``options``)."""
            system = payload.get("system")
            answer = runtime.reply(
                payload.get("prompt", ""),
                system=system if isinstance(system, str) and system else None,
                options=payload.get("options") or {},
            )
            self._send_result(
                payload,
                {
                    "model": runtime.model_name,
                    "created_at": now_iso(),
                    "response": answer,
                    "done": True,
                    "total_duration": time.perf_counter_ns() - started,
                },
            )

        def _handle_chat(self, payload: dict, started: int):
            """Serve ``POST /api/chat`` using the latest user/system messages."""
            user, system = get_last_user_and_system(payload.get("messages") or [])
            answer = runtime.reply(user, system=system, options=payload.get("options") or {})
            self._send_result(
                payload,
                {
                    "model": runtime.model_name,
                    "created_at": now_iso(),
                    "message": {"role": "assistant", "content": answer},
                    "done": True,
                    "total_duration": time.perf_counter_ns() - started,
                },
            )

        def _handle_show(self, payload: dict, started: int):
            """Serve ``POST /api/show`` with the runtime's actual default parameters."""
            self.send_json(
                200,
                {
                    "modelfile": "FROM SOVYN PyTorch checkpoint via local bridge",
                    "parameters": (
                        f"temperature {runtime.temperature}\n"
                        f"top_k {runtime.top_k}\n"
                        f"num_predict {runtime.max_new_tokens}"
                    ),
                    "template": "{{ .Prompt }}",
                    "details": runtime.tags()["models"][0]["details"],
                },
            )

    return Handler
|
|
|
|
def main():
    """Parse CLI options, load the model and serve the Ollama-compatible API.

    Runs until interrupted with Ctrl-C, then closes the listening socket.
    """
    parser = argparse.ArgumentParser(
        description="Serve a SOVYN checkpoint through an Ollama-compatible HTTP API."
    )
    parser.add_argument("--checkpoint", default="checkpoints/sovyn_300m_last.pt",
                        help="PyTorch checkpoint to load")
    parser.add_argument("--tokenizer", default="tokenizer_300m/sovyn.model",
                        help="SentencePiece model file")
    parser.add_argument("--model-name", default="sovyn:300m",
                        help="model name reported to clients")
    parser.add_argument("--host", default="127.0.0.1", help="interface to bind")
    parser.add_argument("--port", type=int, default=11434,
                        help="port to listen on (11434 is Ollama's default)")
    parser.add_argument("--device", default="cuda", help="torch device to run the model on")
    parser.add_argument("--max-new-tokens", type=int, default=64,
                        help="default generation length; requests may override via options.num_predict")
    parser.add_argument("--temperature", type=float, default=0.0,
                        help="default sampling temperature; requests may override via options.temperature")
    parser.add_argument("--top-k", type=int, default=0,
                        help="default top-k; requests may override via options.top_k")
    parser.add_argument("--best-of", type=int, default=1,
                        help="candidates generated and scored per request when temperature > 0")
    args = parser.parse_args()

    runtime = SovynRuntime(args)
    # The context manager closes the listening socket on exit.
    with ThreadingHTTPServer((args.host, args.port), make_handler(runtime)) as server:
        print(f"SOVYN Ollama-compatible API listening on http://{args.host}:{args.port}")
        # Flush so the banner appears even when stdout is piped/buffered.
        print(f"model: {runtime.model_name}, device: {runtime.device}", flush=True)
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            print("shutting down")


if __name__ == "__main__":
    main()
|
|