# RubiNet / ministral_3b_hmc_server.py
# Uploaded with huggingface_hub by DevHunterAI (commit 6a90c85, 20.8 kB).
import argparse
import ast
import json
import operator
import re
import threading
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import urlparse
from ministral_3b_hmc_chat import (
DEFAULT_ADAPTER_DIR,
DEFAULT_MODEL_ID,
SYSTEM_PROMPT,
build_prompt,
generate_reply,
load_model,
)
# Version string reported by /health and rendered on the chat page.
SERVER_VERSION = "ministral-hmc-server-2026-03-22-v1"
# System prompt swapped in when the incoming message looks like a math question.
MATH_SYSTEM_PROMPT = "You are RubiNet. Solve math problems carefully and step by step. Verify arithmetic before answering. Keep the reasoning concise but clear, and end with 'Final answer: ...'."
# Keywords that flag a message as a math query (used by looks_like_math_query).
MATH_KEYWORDS = (
"calculate",
"compute",
"evaluate",
"solve",
"equation",
"math",
"algebra",
"geometry",
"probability",
"percentage",
"percent",
"sum",
"product",
"difference",
"quotient",
)
# Whitelist of AST node types the safe calculator evaluator will accept;
# anything else (names, calls, attributes, ...) is rejected with ValueError.
ALLOWED_CALC_NODES = {
ast.Expression,
ast.BinOp,
ast.UnaryOp,
ast.Constant,
ast.Add,
ast.Sub,
ast.Mult,
ast.Div,
ast.FloorDiv,
ast.Mod,
ast.Pow,
ast.USub,
ast.UAdd,
}
# AST binary-operator node type -> Python operator implementing it.
CALC_BIN_OPS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.FloorDiv: operator.floordiv,
ast.Mod: operator.mod,
ast.Pow: operator.pow,
}
# AST unary-operator node type -> Python operator implementing it.
CALC_UNARY_OPS = {
ast.UAdd: operator.pos,
ast.USub: operator.neg,
}
# Single-page chat UI served at "/".  The placeholders __VERSION__, __MODEL__
# and __ADAPTER__ are substituted by ChatHandler.do_GET before the page is
# sent.  The embedded script provides fetch-based chat plus optional browser
# speech recognition / synthesis.  The string body must stay byte-identical:
# it is runtime output.
HTML_PAGE = """<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>RubiNet Chat</title>
<style>
body { font-family: Arial, sans-serif; margin: 0; background: #111827; color: #f3f4f6; }
.wrap { max-width: 960px; margin: 0 auto; padding: 24px; }
.card { background: #1f2937; border-radius: 16px; padding: 20px; box-shadow: 0 10px 30px rgba(0,0,0,.25); }
h1 { margin-top: 0; font-size: 28px; }
.meta { color: #9ca3af; margin-bottom: 10px; }
#chat { min-height: 360px; max-height: 60vh; overflow-y: auto; padding: 12px; background: #0f172a; border-radius: 12px; margin-bottom: 16px; }
.msg { padding: 12px 14px; border-radius: 12px; margin-bottom: 12px; white-space: pre-wrap; word-break: break-word; }
.user { background: #2563eb; }
.bot { background: #374151; }
form { display: flex; gap: 12px; align-items: stretch; }
textarea { flex: 1; min-height: 96px; max-height: 240px; resize: vertical; border-radius: 12px; border: none; padding: 12px; font: inherit; }
button { border: none; border-radius: 12px; padding: 0 20px; background: #10b981; color: white; font-weight: 700; cursor: pointer; }
button:disabled { background: #6b7280; cursor: wait; }
.status { margin-top: 12px; color: #93c5fd; min-height: 24px; }
.controls { display: flex; gap: 12px; align-items: center; flex-wrap: wrap; margin-bottom: 16px; }
.controls select { border-radius: 12px; border: none; padding: 12px 14px; font: inherit; background: #e5e7eb; color: #111827; }
.secondary { background: #7c3aed; }
.danger { background: #dc2626; }
</style>
</head>
<body>
<div class="wrap">
<div class="card">
<h1>RubiNet Local Chat</h1>
<div class="meta">Version: <code>__VERSION__</code></div>
<div class="meta">Model: <code>__MODEL__</code></div>
<div class="meta">Adapter: <code>__ADAPTER__</code></div>
<div class="controls">
<select id="voice-gender">
<option value="female">Female voice</option>
<option value="male">Male voice</option>
</select>
<select id="speech-language">
<option value="tr-TR">Turkish speech</option>
<option value="en-US">English speech</option>
</select>
<button id="activate-voice" class="secondary" type="button">Activate voice</button>
<button id="stop-voice" class="danger" type="button">Stop voice</button>
</div>
<div id="chat"></div>
<form id="chat-form">
<textarea id="message" placeholder="Type your message..."></textarea>
<button id="send" type="submit">Send</button>
</form>
<div class="status" id="status"></div>
</div>
</div>
<script>
const form = document.getElementById('chat-form');
const messageEl = document.getElementById('message');
const chatEl = document.getElementById('chat');
const statusEl = document.getElementById('status');
const sendEl = document.getElementById('send');
const voiceGenderEl = document.getElementById('voice-gender');
const speechLanguageEl = document.getElementById('speech-language');
const activateVoiceEl = document.getElementById('activate-voice');
const stopVoiceEl = document.getElementById('stop-voice');
const SpeechRecognition = window.SpeechRecognition || window.webkitSpeechRecognition;
const speechSupported = !!SpeechRecognition && 'speechSynthesis' in window;
let recognition = null;
let listening = false;
let voices = [];
let restartingRecognition = false;
let audioStream = null;
function addMessage(role, text) {
const div = document.createElement('div');
div.className = `msg ${role}`;
div.textContent = text;
chatEl.appendChild(div);
chatEl.scrollTop = chatEl.scrollHeight;
}
function updateRecognitionLanguage() {
if (recognition) {
recognition.lang = speechLanguageEl.value || 'en-US';
}
}
async function ensureMicrophoneAccess() {
if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
throw new Error('This browser does not support microphone access.');
}
if (audioStream) {
return audioStream;
}
audioStream = await navigator.mediaDevices.getUserMedia({
audio: {
echoCancellation: true,
noiseSuppression: true,
autoGainControl: true,
}
});
return audioStream;
}
async function startListening() {
if (!recognition) return;
try {
await ensureMicrophoneAccess();
updateRecognitionLanguage();
recognition.start();
listening = true;
restartingRecognition = false;
statusEl.textContent = 'Listening...';
activateVoiceEl.textContent = 'Listening';
} catch (error) {
statusEl.textContent = `Voice activation failed: ${error.message || error}`;
}
}
function loadVoices() {
voices = window.speechSynthesis ? window.speechSynthesis.getVoices() : [];
}
function pickVoice(gender) {
const normalizedGender = gender === 'male' ? 'male' : 'female';
const preferred = normalizedGender === 'male'
? ['male', 'david', 'mark', 'guy', 'james', 'richard', 'george', 'microsoft david', 'microsoft mark']
: ['female', 'zira', 'hazel', 'aria', 'jenny', 'susan', 'sara', 'microsoft zira', 'microsoft aria'];
const lowered = voices.map((voice) => ({ voice, name: `${voice.name} ${voice.voiceURI}`.toLowerCase() }));
for (const token of preferred) {
const match = lowered.find((item) => item.name.includes(token));
if (match) return match.voice;
}
const trMatch = voices.find((voice) => /tr|turkish/i.test(`${voice.lang} ${voice.name}`));
return trMatch || voices[0] || null;
}
function speakReply(text) {
if (!window.speechSynthesis || !text) return;
window.speechSynthesis.cancel();
const utterance = new SpeechSynthesisUtterance(text);
utterance.lang = speechLanguageEl.value || 'en-US';
utterance.rate = 1;
utterance.pitch = voiceGenderEl.value === 'male' ? 0.85 : 1.1;
const voice = pickVoice(voiceGenderEl.value);
if (voice) {
utterance.voice = voice;
utterance.lang = voice.lang || utterance.lang;
}
window.speechSynthesis.speak(utterance);
}
async function sendMessage(message, shouldSpeak = false) {
addMessage('user', message);
statusEl.textContent = 'Generating reply...';
sendEl.disabled = true;
activateVoiceEl.disabled = true;
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), 90000);
try {
const response = await fetch('/chat', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ message }),
signal: controller.signal
});
const data = await response.json();
if (!response.ok) {
addMessage('bot', data.error || 'Unknown error');
return;
}
addMessage('bot', data.reply);
if (shouldSpeak) {
speakReply(data.reply);
}
} catch (error) {
if (error && error.name === 'AbortError') {
addMessage('bot', 'The request timed out. Try a shorter message.');
} else {
addMessage('bot', `Request failed: ${error}`);
}
} finally {
clearTimeout(timeoutId);
statusEl.textContent = listening ? 'Listening...' : '';
sendEl.disabled = false;
activateVoiceEl.disabled = !speechSupported;
messageEl.focus();
}
}
form.addEventListener('submit', async (event) => {
event.preventDefault();
const message = messageEl.value.trim();
if (!message) return;
messageEl.value = '';
await sendMessage(message, false);
});
if (speechSupported) {
loadVoices();
window.speechSynthesis.onvoiceschanged = loadVoices;
recognition = new SpeechRecognition();
recognition.lang = speechLanguageEl.value || 'en-US';
recognition.continuous = false;
recognition.interimResults = false;
recognition.maxAlternatives = 1;
recognition.onstart = () => {
restartingRecognition = false;
statusEl.textContent = 'Listening...';
};
recognition.onresult = async (event) => {
const result = event.results[event.results.length - 1];
if (!result || !result.isFinal) return;
const transcript = result[0].transcript.trim();
if (!transcript) return;
await sendMessage(transcript, true);
};
recognition.onend = () => {
if (listening && !restartingRecognition) {
restartingRecognition = true;
setTimeout(() => {
if (!listening) {
restartingRecognition = false;
return;
}
startListening();
}, 350);
}
};
recognition.onerror = (event) => {
if (event.error === 'no-speech') {
statusEl.textContent = 'No speech detected. Keep speaking closer to the microphone...';
if (listening && !restartingRecognition) {
restartingRecognition = true;
try {
recognition.stop();
} catch (error) {
}
}
return;
}
if (event.error === 'not-allowed' || event.error === 'service-not-allowed') {
listening = false;
activateVoiceEl.textContent = 'Activate voice';
statusEl.textContent = 'Microphone permission was denied. Please allow microphone access in your browser.';
return;
}
if (event.error === 'audio-capture') {
listening = false;
activateVoiceEl.textContent = 'Activate voice';
statusEl.textContent = 'No microphone was found, or another application is using it.';
return;
}
statusEl.textContent = `Voice listening error: ${event.error}`;
};
speechLanguageEl.addEventListener('change', () => {
updateRecognitionLanguage();
if (listening) {
statusEl.textContent = 'Listening...';
}
});
activateVoiceEl.addEventListener('click', async () => {
if (listening) return;
await startListening();
});
stopVoiceEl.addEventListener('click', () => {
listening = false;
restartingRecognition = false;
activateVoiceEl.textContent = 'Activate voice';
statusEl.textContent = '';
window.speechSynthesis.cancel();
if (recognition) {
recognition.stop();
}
if (audioStream) {
for (const track of audioStream.getTracks()) {
track.stop();
}
audioStream = null;
}
});
} else {
activateVoiceEl.disabled = true;
stopVoiceEl.disabled = true;
voiceGenderEl.disabled = true;
speechLanguageEl.disabled = true;
statusEl.textContent = 'This browser does not support speech recognition or speech synthesis.';
}
</script>
</body>
</html>
"""
def looks_like_math_query(message: str) -> bool:
    """Heuristically decide whether *message* is a math question.

    Returns True when the text contains both a digit and an arithmetic
    symbol, or when it mentions one of MATH_KEYWORDS as a whole word.
    """
    normalized = message.strip().lower()
    if not normalized:
        return False
    # A digit together with an operator symbol (e.g. "2+2", "3 × 4") is a
    # strong signal on its own.
    if re.search(r"\d", normalized) and re.search(r"[+\-*/=^×÷%]", normalized):
        return True
    # Whole-word matching: the old substring test made "assume" trigger on
    # "sum" and "production" trigger on "product".
    return any(
        re.search(rf"\b{re.escape(keyword)}\b", normalized)
        for keyword in MATH_KEYWORDS
    )
def extract_simple_expression(message: str) -> tuple[str, str] | None:
normalized = message.strip()
normalized = re.sub(r"(?i)\bwhat is\b", "", normalized)
normalized = re.sub(r"(?i)\bcalculate\b", "", normalized)
normalized = re.sub(r"(?i)\bcompute\b", "", normalized)
normalized = re.sub(r"(?i)\bevaluate\b", "", normalized)
normalized = re.sub(r"(?i)\bsolve\b", "", normalized)
normalized = normalized.replace("×", "*").replace("÷", "/").replace("^", "**")
normalized = normalized.replace("=?", "").replace("= ?", "").replace("=", "")
normalized = normalized.replace("?", "").strip()
if not normalized:
return None
if not re.fullmatch(r"[0-9\s\.+\-*/()%]*", normalized):
return None
if not re.search(r"\d", normalized) or not re.search(r"[+\-*/%()]", normalized):
return None
compact = re.sub(r"\s+", "", normalized)
return normalized, compact
def _eval_calc_node(node):
    """Evaluate one node of a whitelisted arithmetic AST.

    Every node type must appear in ALLOWED_CALC_NODES and every operator in
    CALC_BIN_OPS / CALC_UNARY_OPS, which keeps untrusted calculator input
    away from names, calls and attribute access.  Raises ValueError for
    anything outside the whitelist; numeric results are returned as floats.
    """
    node_type = type(node)
    if node_type not in ALLOWED_CALC_NODES:
        raise ValueError("Unsupported expression.")
    if node_type is ast.Expression:
        return _eval_calc_node(node.body)
    if node_type is ast.Constant:
        value = node.value
        if isinstance(value, (int, float)):
            return float(value)
        raise ValueError("Unsupported constant.")
    if node_type is ast.UnaryOp:
        apply = CALC_UNARY_OPS.get(type(node.op))
        if apply is None:
            raise ValueError("Unsupported unary operator.")
        return apply(_eval_calc_node(node.operand))
    if node_type is ast.BinOp:
        apply = CALC_BIN_OPS.get(type(node.op))
        if apply is None:
            raise ValueError("Unsupported binary operator.")
        return apply(_eval_calc_node(node.left), _eval_calc_node(node.right))
    raise ValueError("Unsupported expression.")
def evaluate_simple_expression(expression: str) -> str:
    """Parse and evaluate *expression*, returning the result as a string.

    Whole-number results are rendered without a decimal point; everything
    else is formatted with up to 12 significant digits.
    """
    value = _eval_calc_node(ast.parse(expression, mode="eval"))
    is_whole = isinstance(value, float) and value.is_integer()
    return str(int(value)) if is_whole else f"{value:.12g}"
class MinistralHMCService:
    """Holds the Ministral model plus a deterministic calculator fast path."""

    def __init__(self, model_id: str, adapter_dir: str, system_prompt: str, max_new_tokens: int, temperature: float, top_p: float, use_4bit: bool, cpu_dtype: str, offload_folder: str):
        self.model_id = model_id
        self.adapter_dir = adapter_dir
        self.system_prompt = system_prompt
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.use_4bit = use_4bit
        self.cpu_dtype = cpu_dtype
        self.offload_folder = offload_folder
        # Populated by load(); None until then.
        self.tokenizer = None
        self.model = None
        # Serializes generation: the HTTP server is multi-threaded but the
        # model is not safe for concurrent generation.
        self._generation_lock = threading.Lock()

    def load(self):
        """Load tokenizer and model once at startup (blocking)."""
        self.tokenizer, self.model = load_model(
            self.model_id,
            self.adapter_dir,
            self.use_4bit,
            self.cpu_dtype,
            self.offload_folder,
        )

    def reply(self, message: str) -> str:
        """Answer *message*: exact arithmetic when possible, else the model."""
        with self._generation_lock:
            parsed = extract_simple_expression(message)
            if parsed is not None:
                # Deterministic path: evaluate the expression exactly instead
                # of trusting the language model with arithmetic.
                shown, compact = parsed
                answer = evaluate_simple_expression(compact)
                return (
                    f"Expression: {shown}\n"
                    f"Verified result: {answer}\n"
                    f"Final answer: {answer}"
                )
            if looks_like_math_query(message):
                active_prompt = MATH_SYSTEM_PROMPT
            else:
                active_prompt = self.system_prompt
            return generate_reply(
                self.tokenizer,
                self.model,
                build_prompt(message, active_prompt),
                self.max_new_tokens,
                self.temperature,
                self.top_p,
            )
class ChatHandler(BaseHTTPRequestHandler):
    """HTTP handler exposing the chat page at "/", /health, and POST /chat."""

    # Shared MinistralHMCService instance; assigned in main() before serving.
    service = None

    def _send_json(self, payload, status=HTTPStatus.OK):
        # Encode first so Content-Length matches the exact byte count.
        body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_html(self, html: str):
        body = html.encode("utf-8")
        self.send_response(HTTPStatus.OK)
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        path = urlparse(self.path).path
        if path == "/health":
            self._send_json({
                "status": "ok",
                "model": self.service.model_id,
                "adapter": self.service.adapter_dir,
                "version": SERVER_VERSION,
            })
            return
        if path != "/":
            self.send_error(HTTPStatus.NOT_FOUND)
            return
        # Substitute the metadata placeholders baked into the page template.
        page = (
            HTML_PAGE
            .replace("__VERSION__", SERVER_VERSION)
            .replace("__MODEL__", self.service.model_id)
            .replace("__ADAPTER__", self.service.adapter_dir)
        )
        self._send_html(page)

    def do_POST(self):
        if urlparse(self.path).path != "/chat":
            self.send_error(HTTPStatus.NOT_FOUND)
            return
        try:
            length = int(self.headers.get("Content-Length", "0"))
            payload = json.loads(self.rfile.read(length).decode("utf-8"))
            message = str(payload.get("message", "")).strip()
            if not message:
                self._send_json({"error": "Message cannot be empty."}, status=HTTPStatus.BAD_REQUEST)
                return
            self._send_json({"reply": self.service.reply(message)})
        except Exception as exc:
            # Report the failure to the browser rather than dropping the
            # connection; generation and JSON errors both land here.
            self._send_json({"error": str(exc)}, status=HTTPStatus.INTERNAL_SERVER_ERROR)

    def log_message(self, format, *args):
        # Suppress per-request logging to keep the console clean.
        return
def main():
    """Parse CLI options, load the model, and serve HTTP until interrupted."""
    parser = argparse.ArgumentParser(description="Serve Ministral 3B HMC on a local web server")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8036)
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    parser.add_argument("--adapter-dir", default=DEFAULT_ADAPTER_DIR)
    parser.add_argument("--system-prompt", default=SYSTEM_PROMPT)
    parser.add_argument("--max-new-tokens", type=int, default=32)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--use-4bit", action="store_true")
    parser.add_argument("--cpu-dtype", choices=["float32", "float16", "bfloat16"], default="bfloat16")
    parser.add_argument("--offload-folder", default=r"C:\Users\ASUS\CascadeProjects\.hf-offload")
    options = parser.parse_args()

    service = MinistralHMCService(
        model_id=options.model_id,
        adapter_dir=options.adapter_dir,
        system_prompt=options.system_prompt,
        max_new_tokens=options.max_new_tokens,
        temperature=options.temperature,
        top_p=options.top_p,
        use_4bit=options.use_4bit,
        cpu_dtype=options.cpu_dtype,
        offload_folder=options.offload_folder,
    )
    print("Loading Ministral 3B HMC model...")
    service.load()
    print(f"Ministral 3B HMC server ready at http://{options.host}:{options.port}")

    # Hand the loaded service to the handler class, then serve forever.
    ChatHandler.service = service
    server = ThreadingHTTPServer((options.host, options.port), ChatHandler)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        server.server_close()


if __name__ == "__main__":
    main()