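# sovyn-300m-cortex / scripts/ollama_bridge.py
# Serves a SOVYN PyTorch checkpoint behind an Ollama-compatible HTTP API.
# Uploaded by SOVYN with huggingface_hub ("Upload folder using huggingface_hub"),
# commit 681909f (verified).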
import argparse
import json
import sys
import time
from datetime import UTC, datetime
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
import sentencepiece as spm
import torch
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
from sovyn import SovynConfig, SovynForCausalLM
from sovyn.formatting import format_prompt
from chat import clean_answer, score_answer
def now_iso():
return datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z")
class SovynRuntime:
def __init__(self, args):
self.model_name = args.model_name
self.max_new_tokens = args.max_new_tokens
self.temperature = args.temperature
self.top_k = args.top_k
self.best_of = args.best_of
self.checkpoint_path = Path(args.checkpoint)
device = args.device
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
self.device = device
self.tokenizer = spm.SentencePieceProcessor(model_file=args.tokenizer)
checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
model_cfg = checkpoint["config"]["model"]
self.model = SovynForCausalLM(SovynConfig(**model_cfg))
self.model.load_state_dict(checkpoint["model"])
dtype = torch.bfloat16 if device == "cuda" else torch.float32
self.model.to(device=device, dtype=dtype)
self.model.eval()
self.eos_id = self.tokenizer.piece_to_id("<eos>")
self.stop_ids = [
self.tokenizer.piece_to_id(piece)
for piece in ["<system>", "<user>", "<state>", "<plan>", "<memory>", "<reflection>"]
if self.tokenizer.piece_to_id(piece) >= 0
]
self.suppress_ids = [
idx
for idx in [
self.tokenizer.piece_to_id("<pad>"),
self.tokenizer.piece_to_id("<unk>"),
self.tokenizer.piece_to_id("<bos>"),
]
if idx >= 0
]
@torch.no_grad()
def reply(self, user: str, system: str | None = None, options: dict | None = None) -> str:
options = options or {}
temperature = float(options.get("temperature", self.temperature))
top_k = int(options.get("top_k", self.top_k))
max_new_tokens = int(options.get("num_predict", self.max_new_tokens))
best_of = max(1, int(options.get("best_of", self.best_of)))
runs = best_of if temperature > 0 else 1
prompt = format_prompt(user, system=system)
ids = torch.tensor(
[self.tokenizer.encode(prompt, out_type=int)],
dtype=torch.long,
device=self.device,
)
candidates = []
for _ in range(runs):
out = self.model.generate(
ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
eos_id=self.eos_id,
stop_ids=self.stop_ids,
suppress_ids=self.suppress_ids,
)
answer = clean_answer(self.tokenizer.decode(out[0].tolist()))
candidates.append(answer)
return max(candidates, key=lambda answer: score_answer(user, answer))
def tags(self):
size = self.checkpoint_path.stat().st_size if self.checkpoint_path.exists() else 0
return {
"models": [
{
"name": self.model_name,
"model": self.model_name,
"modified_at": now_iso(),
"size": size,
"digest": "sovyn-local-pytorch",
"details": {
"parent_model": "",
"format": "pytorch",
"family": "sovyn",
"families": ["sovyn"],
"parameter_size": "300M",
"quantization_level": "BF16",
},
}
]
}
def json_bytes(payload: dict) -> bytes:
    """Serialize ``payload`` to UTF-8 encoded JSON, keeping non-ASCII characters literal."""
    text = json.dumps(payload, ensure_ascii=False)
    return bytes(text, encoding="utf-8")
def get_last_user_and_system(messages: list[dict]) -> tuple[str, str | None]:
system = None
user = ""
for message in messages:
role = message.get("role")
content = message.get("content", "")
if role == "system" and content:
system = content
elif role == "user" and content:
user = content
return user, system
def make_handler(runtime: "SovynRuntime"):
    """Return a ``BaseHTTPRequestHandler`` subclass that serves ``runtime``.

    The handler implements the subset of the Ollama HTTP API that the bridge
    supports:

    * ``HEAD /``, ``GET /`` and ``GET /api/version`` return the bridge version.
    * ``GET /api/tags`` returns ``runtime.tags()``.
    * ``POST /api/generate`` (``prompt`` plus optional ``system``) and
      ``POST /api/chat`` (``messages``) each run one generation. The result is
      sent as a single NDJSON line when ``stream`` is true, missing or null, and
      as a plain JSON object when ``stream`` is false.
    * ``POST /api/show`` returns model details and the server's default
      sampling parameters.

    Error responses are ``{"error": message}`` with these statuses:

    * 404 for an unknown route.
    * 400 for a malformed request body.
    * 500 for an exception raised while building a response.
    """

    class Handler(BaseHTTPRequestHandler):
        server_version = "SOVYN-Ollama-Bridge/0.1"

        def log_message(self, fmt, *args):
            # Log to stdout and flush right away (the base class writes to stderr).
            sys.stdout.write("%s - %s\n" % (self.address_string(), fmt % args))
            sys.stdout.flush()

        def _send_body(self, status: int, content_type: str, body: bytes, *, include_body: bool = True):
            """Send status and headers (with an exact Content-Length), then ``body``.

            ``include_body=False`` sends the headers only, as a HEAD reply must.
            """
            self.send_response(status)
            self.send_header("Content-Type", content_type)
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            if include_body:
                self.wfile.write(body)

        def send_json(self, status: int, payload: dict):
            """Send ``payload`` as one JSON document with the given ``status``."""
            self._send_body(status, "application/json; charset=utf-8", json_bytes(payload))

        def send_stream_json(self, payload: dict):
            """Send ``payload`` as a one-line NDJSON stream (a single, final chunk)."""
            self._send_body(200, "application/x-ndjson; charset=utf-8", json_bytes(payload) + b"\n")

        def read_payload(self) -> dict:
            """Read the request body as a JSON object; a missing or blank body yields ``{}``.

            Raises:
                ValueError: non-integer Content-Length, body that is not UTF-8,
                    invalid JSON, or JSON that is not an object.
            """
            length = int(self.headers.get("Content-Length", "0"))
            if length <= 0:
                return {}
            raw = self.rfile.read(length).decode("utf-8")
            if not raw.strip():
                return {}
            payload = json.loads(raw)
            if not isinstance(payload, dict):
                raise ValueError("request body must be a JSON object")
            return payload

        def _get_response(self) -> tuple[int, dict]:
            """Resolve a GET/HEAD request path to ``(status, payload)``."""
            if self.path in ("/", "/api/version"):
                return 200, {"version": "sovyn-ollama-bridge-0.1"}
            if self.path == "/api/tags":
                return 200, runtime.tags()
            return 404, {"error": f"unknown route: {self.path}"}

        def do_GET(self):
            self.send_json(*self._get_response())

        def do_HEAD(self):
            # Same status and headers as GET, but no body. Ollama's client probes
            # liveness with HEAD /; without this method the base class answers 501.
            status, payload = self._get_response()
            self._send_body(
                status, "application/json; charset=utf-8", json_bytes(payload), include_body=False
            )

        def _is_streaming(self, payload: dict) -> bool:
            # Ollama streams by default, so a missing or null "stream" means true.
            stream = payload.get("stream")
            return True if stream is None else bool(stream)

        def _generate(self, payload: dict, started: int) -> tuple[dict, bool]:
            """Build the ``POST /api/generate`` reply and whether to stream it."""
            answer = runtime.reply(
                payload.get("prompt") or "",
                system=payload.get("system") or None,
                options=payload.get("options") or {},
            )
            response = {
                "model": runtime.model_name,
                "created_at": now_iso(),
                "response": answer,
                "done": True,
                "total_duration": time.perf_counter_ns() - started,
            }
            return response, self._is_streaming(payload)

        def _chat(self, payload: dict, started: int) -> tuple[dict, bool]:
            """Build the ``POST /api/chat`` reply and whether to stream it."""
            user, system = get_last_user_and_system(payload.get("messages") or [])
            answer = runtime.reply(user, system=system, options=payload.get("options") or {})
            response = {
                "model": runtime.model_name,
                "created_at": now_iso(),
                "message": {"role": "assistant", "content": answer},
                "done": True,
                "total_duration": time.perf_counter_ns() - started,
            }
            return response, self._is_streaming(payload)

        def _show(self, payload: dict, started: int) -> tuple[dict, bool]:
            """Build the ``POST /api/show`` reply (never streamed)."""
            response = {
                "modelfile": "FROM SOVYN PyTorch checkpoint via local bridge",
                # Report the server's actual default sampling settings.
                "parameters": f"temperature {runtime.temperature}\ntop_k {runtime.top_k}",
                "template": "{{ .Prompt }}",
                "details": runtime.tags()["models"][0]["details"],
            }
            return response, False

        def do_POST(self):
            started = time.perf_counter_ns()
            try:
                payload = self.read_payload()
            except ValueError as exc:
                self.send_json(400, {"error": f"invalid request body: {exc}"})
                return
            route = {
                "/api/generate": self._generate,
                "/api/chat": self._chat,
                "/api/show": self._show,
            }.get(self.path)
            if route is None:
                self.send_json(404, {"error": f"unknown route: {self.path}"})
                return
            try:
                response, stream = route(payload, started)
            except Exception as exc:
                # Request boundary: log server-side and report to the client. Only
                # response *building* is guarded, so a failed write below is never
                # followed by a second response on the same connection.
                self.log_error("error handling %s: %r", self.path, exc)
                self.send_json(500, {"error": str(exc)})
                return
            if stream:
                self.send_stream_json(response)
            else:
                self.send_json(200, response)

    return Handler
def main():
    """Parse command-line options, load the model and serve the bridge until Ctrl-C."""
    parser = argparse.ArgumentParser(
        description="Serve a SOVYN checkpoint behind an Ollama-compatible HTTP API."
    )
    parser.add_argument("--checkpoint", default="checkpoints/sovyn_300m_last.pt")
    parser.add_argument("--tokenizer", default="tokenizer_300m/sovyn.model")
    parser.add_argument("--model-name", default="sovyn:300m")
    parser.add_argument("--host", default="127.0.0.1")
    # 11434 is the port a real Ollama server listens on by default.
    parser.add_argument("--port", type=int, default=11434)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--max-new-tokens", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--best-of", type=int, default=1)
    args = parser.parse_args()

    runtime = SovynRuntime(args)
    server = ThreadingHTTPServer((args.host, args.port), make_handler(runtime))
    # flush=True so the banner appears promptly even when stdout is piped,
    # matching the handler's request log, which flushes on every line.
    print(f"SOVYN Ollama-compatible API listening on http://{args.host}:{args.port}", flush=True)
    print(f"model: {runtime.model_name}, device: {runtime.device}", flush=True)
    # The context manager closes the listening socket on exit; Ctrl-C ends
    # serving without a traceback.
    with server:
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            print("shutting down", flush=True)
# Script entry point: run the bridge only when executed directly, not on import.
# main() returns None, so the process exits with status 0 after it returns.
if __name__ == "__main__":
    raise SystemExit(main())