"""Ollama-compatible HTTP bridge serving a local SOVYN PyTorch checkpoint.

Implements a small subset of the Ollama REST API (``/``, ``/api/version``,
``/api/tags``, ``/api/generate``, ``/api/chat`` and ``/api/show``) on top of
``SovynForCausalLM`` so existing Ollama clients can talk to the model.
"""

import argparse
import json
import sys
import time
import traceback
from datetime import UTC, datetime
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path

import sentencepiece as spm
import torch

# Make the project's ``src`` directory importable when run as a script.
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from sovyn import SovynConfig, SovynForCausalLM  # noqa: E402
from sovyn.formatting import format_prompt  # noqa: E402
from chat import clean_answer, score_answer  # noqa: E402


def _to_iso_z(moment: datetime) -> str:
    """Format ``moment`` as UTC ISO-8601 with millisecond precision and a ``Z`` suffix."""
    return (
        moment.astimezone(UTC)
        .isoformat(timespec="milliseconds")
        .replace("+00:00", "Z")
    )


def now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string ending in ``Z``."""
    return _to_iso_z(datetime.now(UTC))


class SovynRuntime:
    """Owns the tokenizer and the loaded model, and turns prompts into answers."""

    def __init__(self, args: argparse.Namespace):
        """Load tokenizer and checkpoint as described by the parsed CLI ``args``."""
        self.model_name = args.model_name
        self.max_new_tokens = args.max_new_tokens
        self.temperature = args.temperature
        self.top_k = args.top_k
        self.best_of = args.best_of
        self.checkpoint_path = Path(args.checkpoint)

        device = args.device
        # Any CUDA spec ("cuda", "cuda:1", ...) falls back to CPU without a GPU.
        if torch.device(device).type == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        self.device = device
        on_cuda = torch.device(device).type == "cuda"
        # bf16 on GPU, full precision on CPU; also reported by tags().
        self.dtype = torch.bfloat16 if on_cuda else torch.float32

        self.tokenizer = spm.SentencePieceProcessor(model_file=args.tokenizer)

        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
        model_cfg = checkpoint["config"]["model"]
        self.model = SovynForCausalLM(SovynConfig(**model_cfg))
        self.model.load_state_dict(checkpoint["model"])
        self.model.to(device=device, dtype=self.dtype)
        self.model.eval()

        # NOTE(review): the piece literals below are empty strings, which no
        # vocabulary contains -- confirm the intended special tokens. For a
        # missing piece piece_to_id() yields unk_id(), so eos_id would be unk.
        self.eos_id = self.tokenizer.piece_to_id("")
        self.stop_ids = self._known_piece_ids(["", "", "", "", "", ""])
        self.suppress_ids = self._known_piece_ids(["", "", ""])

    def _known_piece_ids(self, pieces: list[str]) -> list[int]:
        """Return vocabulary ids for those ``pieces`` the tokenizer really contains.

        SentencePiece's ``piece_to_id`` maps unknown pieces to ``unk_id()``
        rather than a negative value, so a ``>= 0`` test alone would turn a
        missing special token into the unknown token. The id -> piece round
        trip filters those out.
        """
        ids = []
        for piece in pieces:
            idx = self.tokenizer.piece_to_id(piece)
            if idx >= 0 and self.tokenizer.id_to_piece(idx) == piece:
                ids.append(idx)
        return ids

    @torch.no_grad()
    def reply(self, user: str, system: str | None = None, options: dict | None = None) -> str:
        """Generate an answer for ``user``.

        Args:
            user: the user prompt text.
            system: optional system prompt forwarded to ``format_prompt``.
            options: Ollama-style overrides; reads ``temperature``, ``top_k``,
                ``num_predict`` (max new tokens) and ``best_of``.

        Returns:
            The cleaned candidate with the highest ``score_answer`` score.
        """
        options = options or {}
        temperature = float(options.get("temperature", self.temperature))
        top_k = int(options.get("top_k", self.top_k))
        max_new_tokens = int(options.get("num_predict", self.max_new_tokens))
        best_of = max(1, int(options.get("best_of", self.best_of)))
        # Greedy decoding (temperature <= 0) presumably repeats the same output,
        # so extra candidates are only sampled when temperature > 0.
        runs = best_of if temperature > 0 else 1

        prompt = format_prompt(user, system=system)
        ids = torch.tensor(
            [self.tokenizer.encode(prompt, out_type=int)],
            dtype=torch.long,
            device=self.device,
        )
        candidates = []
        for _ in range(runs):
            out = self.model.generate(
                ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                eos_id=self.eos_id,
                stop_ids=self.stop_ids,
                suppress_ids=self.suppress_ids,
            )
            candidates.append(clean_answer(self.tokenizer.decode(out[0].tolist())))
        return max(candidates, key=lambda answer: score_answer(user, answer))

    def tags(self) -> dict:
        """Describe the served model in the shape of Ollama's ``/api/tags`` response."""
        if self.checkpoint_path.exists():
            stat = self.checkpoint_path.stat()
            size = stat.st_size
            # Report when the checkpoint changed, not when the request arrived.
            modified_at = _to_iso_z(datetime.fromtimestamp(stat.st_mtime, UTC))
        else:
            size = 0
            modified_at = now_iso()
        return {
            "models": [
                {
                    "name": self.model_name,
                    "model": self.model_name,
                    "modified_at": modified_at,
                    "size": size,
                    "digest": "sovyn-local-pytorch",
                    "details": {
                        "parent_model": "",
                        "format": "pytorch",
                        "family": "sovyn",
                        "families": ["sovyn"],
                        "parameter_size": "300M",
                        # Reflect the dtype actually in use (float32 on CPU).
                        "quantization_level": "BF16" if self.dtype == torch.bfloat16 else "F32",
                    },
                }
            ]
        }


def json_bytes(payload: dict) -> bytes:
    """Serialize ``payload`` to UTF-8 JSON, keeping non-ASCII characters as-is."""
    return json.dumps(payload, ensure_ascii=False).encode("utf-8")


def get_last_user_and_system(messages: list[dict]) -> tuple[str, str | None]:
    """Return the last non-empty user message and last non-empty system message.

    Only the final user turn is used; earlier conversation history is dropped.
    Returns ``("", None)`` when no such messages exist.
    """
    system = None
    user = ""
    for message in messages:
        role = message.get("role")
        content = message.get("content", "")
        if role == "system" and content:
            system = content
        elif role == "user" and content:
            user = content
    return user, system


def make_handler(runtime: SovynRuntime) -> type[BaseHTTPRequestHandler]:
    """Build a request-handler class whose routes are served by ``runtime``.

    NOTE(review): ThreadingHTTPServer handles requests on separate threads that
    share one model -- confirm ``generate`` is safe to call concurrently.
    """

    class Handler(BaseHTTPRequestHandler):
        server_version = "SOVYN-Ollama-Bridge/0.1"

        def log_message(self, fmt, *args):
            """Write access/error log lines to stdout instead of stderr."""
            sys.stdout.write("%s - %s\n" % (self.address_string(), fmt % args))
            sys.stdout.flush()

        def send_json(self, status: int, payload: dict):
            """Send ``payload`` as a complete JSON response with ``status``."""
            body = json_bytes(payload)
            self.send_response(status)
            self.send_header("Content-Type", "application/json; charset=utf-8")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def send_stream_json(self, payload: dict):
            """Send ``payload`` as a single NDJSON line (one-chunk "stream")."""
            body = json_bytes(payload) + b"\n"
            self.send_response(200)
            self.send_header("Content-Type", "application/x-ndjson; charset=utf-8")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

        def send_result(self, payload: dict, response: dict):
            """Reply with ``response``; streaming is the default, as in Ollama."""
            if payload.get("stream", True):
                self.send_stream_json(response)
            else:
                self.send_json(200, response)

        def read_payload(self) -> dict:
            """Read and decode the JSON request body.

            Returns ``{}`` for an empty body. Raises ``ValueError`` (including
            ``json.JSONDecodeError`` / ``UnicodeDecodeError``) for a bad
            Content-Length, non-UTF-8 bytes, invalid JSON or a non-object body.
            """
            length = int(self.headers.get("Content-Length", "0"))
            if length <= 0:
                return {}
            raw = self.rfile.read(length).decode("utf-8")
            if not raw:
                return {}
            payload = json.loads(raw)
            if not isinstance(payload, dict):
                raise ValueError("request body must be a JSON object")
            return payload

        def _route(self) -> str:
            """Return the request path without any query string."""
            return self.path.split("?", 1)[0]

        def do_GET(self):
            route = self._route()
            if route in ("/", "/api/version"):
                self.send_json(200, {"version": "sovyn-ollama-bridge-0.1"})
            elif route == "/api/tags":
                self.send_json(200, runtime.tags())
            else:
                self.send_json(404, {"error": f"unknown route: {self.path}"})

        def do_POST(self):
            started = time.perf_counter_ns()
            try:
                payload = self.read_payload()
            except ValueError as exc:
                # Malformed input is the client's fault, not a server error.
                self.send_json(400, {"error": f"invalid request body: {exc}"})
                return

            route = self._route()
            try:
                if route == "/api/generate":
                    answer = runtime.reply(
                        payload.get("prompt", ""),
                        system=payload.get("system") or None,
                        options=payload.get("options") or {},
                    )
                    self.send_result(
                        payload,
                        {
                            "model": runtime.model_name,
                            "created_at": now_iso(),
                            "response": answer,
                            "done": True,
                            "total_duration": time.perf_counter_ns() - started,
                        },
                    )
                elif route == "/api/chat":
                    user, system = get_last_user_and_system(payload.get("messages", []))
                    answer = runtime.reply(
                        user, system=system, options=payload.get("options") or {}
                    )
                    self.send_result(
                        payload,
                        {
                            "model": runtime.model_name,
                            "created_at": now_iso(),
                            "message": {"role": "assistant", "content": answer},
                            "done": True,
                            "total_duration": time.perf_counter_ns() - started,
                        },
                    )
                elif route == "/api/show":
                    self.send_json(
                        200,
                        {
                            "modelfile": "FROM SOVYN PyTorch checkpoint via local bridge",
                            # Report the defaults actually used by runtime.reply().
                            "parameters": (
                                f"temperature {runtime.temperature}\n"
                                f"top_k {runtime.top_k}\n"
                                f"num_predict {runtime.max_new_tokens}"
                            ),
                            "template": "{{ .Prompt }}",
                            "details": runtime.tags()["models"][0]["details"],
                        },
                    )
                else:
                    self.send_json(404, {"error": f"unknown route: {self.path}"})
            except (BrokenPipeError, ConnectionResetError):
                # The client went away; there is nobody left to send an error to.
                self.log_error("client disconnected during %s", self.path)
            except Exception as exc:
                self.log_error("request to %s failed:\n%s", self.path, traceback.format_exc())
                self.send_json(500, {"error": str(exc)})

    return Handler


def main():
    """Parse CLI arguments, load the model and serve the bridge until interrupted."""
    parser = argparse.ArgumentParser(
        description="Serve a SOVYN checkpoint behind an Ollama-compatible HTTP API."
    )
    parser.add_argument("--checkpoint", default="checkpoints/sovyn_300m_last.pt",
                        help="path to the PyTorch checkpoint")
    parser.add_argument("--tokenizer", default="tokenizer_300m/sovyn.model",
                        help="path to the SentencePiece model")
    parser.add_argument("--model-name", default="sovyn:300m",
                        help="model name reported to clients")
    parser.add_argument("--host", default="127.0.0.1")
    # 11434 is the port Ollama itself listens on by default.
    parser.add_argument("--port", type=int, default=11434)
    parser.add_argument("--device", default="cuda",
                        help="torch device; CUDA specs fall back to cpu without a GPU")
    parser.add_argument("--max-new-tokens", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--best-of", type=int, default=1,
                        help="candidates sampled per request when temperature > 0")
    args = parser.parse_args()

    runtime = SovynRuntime(args)
    server = ThreadingHTTPServer((args.host, args.port), make_handler(runtime))
    print(f"SOVYN Ollama-compatible API listening on http://{args.host}:{args.port}")
    print(f"model: {runtime.model_name}, device: {runtime.device}")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        server.server_close()


if __name__ == "__main__":
    main()