import os, time, threading
# ── MKL / OpenMP tuning BEFORE torch import ───────────────────────────────────
# Target host is a 1-vCPU machine (HF free CPU tier): extra worker threads would
# only contend for the same core. These variables configure the BLAS / OpenMP
# runtimes, so they have to be in place before `import torch` further down.
# setdefault keeps any value the deployment already exported.
for _thread_env in (
    "MKL_NUM_THREADS",
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ.setdefault(_thread_env, "1")
del _thread_env
# DNNL_VERBOSE is the verbose-logging switch of oneDNN (MKL-DNN, PyTorch's CPU
# kernel library); "0" keeps that logging off.
os.environ.setdefault("DNNL_VERBOSE", "0")
import torch
import torch.backends.mkldnn
# Pin torch's own thread pools to a single thread, in line with the env vars
# set before the torch import:
#   set_num_threads         -> intra-op pool (parallelism inside a single op)
#   set_num_interop_threads -> inter-op pool (parallelism across ops)
# NOTE(review): per the PyTorch docs, set_num_interop_threads can only be called
# once and before any inter-op parallel work starts, so keep it at import time.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
from flask import Flask, request, jsonify, Response, send_from_directory, stream_with_context
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, QuantoConfig
# Flask application; every route below is registered on it.
app = Flask(__name__)
# System turn that build_prompt() places in front of every user message.
SYSTEM_PROMPT = (
"You are Cygnis-Alpha, a helpful AI assistant created by CygnisAI.\n\n"
"## ABSOLUTE RULES\n"
"1. LANGUAGE: Detect the user's language. You MUST reply 100% in the SAME language as the user's last message. Never switch to English unless the user asks you to.\n"
"2. IDENTITY: Your name is Cygnis-Alpha, created by CygnisAI.\n"
"3. HONESTY: Never invent facts. If you don't know, say it.\n"
"4. FOCUS: Answer only what was asked. No yapping.\n\n"
"## STYLE\n"
"- Tone: Warm, friendly, professional.\n"
"- Length: Be extremely concise (short answers). Only detail if explicitly requested."
)
# Hub repo id handed to from_pretrained() in load_model() and echoed in API responses.
MODEL_ID = "CygnisAI/Cygnis-Alpha-1.7B-v0.1-Instruct"
# Inline SVG served by /favicon.svg and as the /favicon.ico|png fallback.
# NOTE(review): as it appears here this literal holds only a newline, so an
# empty "image" is served — confirm the SVG markup was not lost.
FAVICON_SVG = """
"""
# Landing page returned by "/" with mimetype text/html.
# NOTE(review): as it appears here the page contains no HTML tags, so a browser
# renders it as one run of text — confirm markup was not stripped. It also
# always advertises "INT8 quanto", even when load_model() fell back to float32.
HTML_PAGE = """
Cygnis-Alpha · Online
API Endpoints
POST /generate — JSON complet
POST /generate/stream — SSE token par token
Body: {"prompt": "...", "max_new_tokens": 256, "fast": true}
SSE: data: token · data: [STATS]... · data: [DONE]
⚡ INT8 quanto
"""
# ─── Model globals ─────────────────────────────────────────────────────────────
# Written by load_model() on its background thread; read by the request handlers.
tokenizer_g = None  # tokenizer instance once loaded
model_g = None  # causal-LM model (possibly torch.compile-wrapped) once loaded
model_ready = False  # flipped to True by load_model() once the model is usable
model_error = None  # str() of the load exception if loading failed
quant_mode = "none" # updated after load
def load_model():
    """Load tokenizer and model into the module globals, then warm up.

    Intended to run on a background thread. On success it publishes
    ``tokenizer_g``, ``model_g`` and ``quant_mode`` and, only after the warmup
    pass has finished, sets ``model_ready`` so that ``guard()`` starts
    admitting requests. Any exception escaping the load is stored as text in
    ``model_error`` (reported by ``guard()`` and ``/health``).
    """
    global tokenizer_g, model_g, model_ready, model_error, quant_mode
    try:
        print(f"[CygnisAI] Loading {MODEL_ID} ...")
        tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
        if tok.pad_token is None:
            # Ensure a pad token exists; reuse EOS when the tokenizer defines none.
            tok.pad_token = tok.eos_token

        # ── INT8 quantization via quanto ──────────────────────────────────
        # int8 weights need a quarter of the memory of float32 weights.
        # NOTE(review): whether matmuls actually get faster on this CPU backend
        # is not established here — measure. Any exception from the quantized
        # load falls back to a plain float32 load.
        try:
            qconfig = QuantoConfig(weights="int8")
            mdl = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                quantization_config=qconfig,
                low_cpu_mem_usage=True,
            )
            quant_mode = "int8-quanto"
            print("[CygnisAI] ✅ INT8 quantization loaded.")
        except Exception as qe:
            print(f"[CygnisAI] INT8 failed ({qe}), falling back to float32 …")
            mdl = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )
            quant_mode = "float32"
        mdl.eval()

        # ── torch.compile ─────────────────────────────────────────────────
        # NOTE(review): torch.compile returns a wrapper module that forwards
        # attribute access (including `.generate`) to the original module, whose
        # generate() then calls the *uncompiled* forward, so this may not speed
        # up generation at all. mode="reduce-overhead" is documented as a
        # CUDA-graphs mode. Measure before relying on it.
        try:
            mdl = torch.compile(mdl, mode="reduce-overhead", fullgraph=False)
            print("[CygnisAI] torch.compile OK.")
        except Exception as ce:
            print(f"[CygnisAI] torch.compile skipped: {ce}")

        tokenizer_g = tok
        model_g = mdl
        # Warm up BEFORE publishing model_ready. Otherwise guard() would admit
        # real requests while the warmup generate() is still running, which
        # defeats the point of warming up. _warmup() logs and swallows its own
        # errors, so readiness is still published if it fails.
        _warmup()
        model_ready = True
        print(f"[CygnisAI] Model ready. Mode: {quant_mode}")
    except Exception as e:
        model_error = str(e)
        print(f"[CygnisAI] Load error: {e}")
def _warmup():
    """Run one tiny greedy generation so start-up costs are paid before real traffic.

    Best-effort: any failure is printed and swallowed.
    """
    try:
        print("[CygnisAI] Warming up ...")
        encoded = tokenizer_g("Hi", return_tensors="pt")
        warm_args = {
            "max_new_tokens": 4,
            "do_sample": False,
            "use_cache": True,
            "pad_token_id": tokenizer_g.eos_token_id,
        }
        with torch.inference_mode():
            model_g.generate(**encoded, **warm_args)
        print("[CygnisAI] Warmup done — ready to serve.")
    except Exception as warm_err:
        print(f"[CygnisAI] Warmup error (non-fatal): {warm_err}")
# Kick off model loading in the background so the server can already answer
# (/health reports "loading") while the model loads. daemon=True means this
# thread never keeps the process alive on shutdown.
_loader_thread = threading.Thread(target=load_model, daemon=True)
_loader_thread.start()
# ─── Helpers ──────────────────────────────────────────────────────────────────
def build_prompt(user_prompt: str) -> str:
    """Return the complete chat-formatted prompt for a single user message.

    Layout: a system turn holding ``SYSTEM_PROMPT``, the user turn, and an
    opened assistant turn for the model to complete.

    Any ``<|im_start|>`` / ``<|im_end|>`` markers are removed from
    *user_prompt* first. The text comes straight from the HTTP request body,
    and those markers could otherwise close the user's turn early and inject
    extra system/assistant turns into the conversation.
    """
    markers = ("<|im_start|>", "<|im_end|>")
    cleaned = user_prompt
    # Repeat until stable: deleting one marker can join the text around it into
    # a new marker (e.g. "<|im_<|im_end|>start|>"). Each pass shortens the text,
    # so the loop terminates.
    while any(marker in cleaned for marker in markers):
        for marker in markers:
            cleaned = cleaned.replace(marker, "")
    return (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{cleaned}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
def guard():
    """Report whether the model can serve requests.

    Returns ``(True, None)`` when ready, otherwise ``(False, (response, status))``.
    A load failure (500) takes precedence over "still loading" (503).
    """
    if model_error:
        payload, status = {"error": f"Model failed to load: {model_error}"}, 500
    elif not model_ready:
        payload, status = {"error": "Model is loading, retry in a moment."}, 503
    else:
        return True, None
    return False, (jsonify(payload), status)
def parse_body():
    """Parse and validate the JSON body shared by /generate and /generate/stream.

    Returns ``(prompt, max_tok, temperature, fast, err)``. On success ``err``
    is None. On invalid input the first four items are None and ``err`` is a
    ``(response, 400)`` pair the view can return directly.

    Fields:
        prompt: required; must be non-empty after stripping (``null`` counts
            as missing).
        max_new_tokens: integer, default 256 (also when ``null``), clamped to
            the range 1..512.
        temperature: number, default 0.7 (also when ``null``). Only used when
            sampling, where it must be finite and > 0 (transformers' sampling
            rejects non-positive temperatures).
        fast: default true, meaning greedy decoding (~2× faster). The strings
            "", "false", "0", "no" and "off" (any case) count as false.
    """
    import math

    def _reject(message):
        return None, None, None, None, (jsonify({"error": message}), 400)

    data = request.get_json(silent=True)
    # A JSON array/string/number parses fine but has no .get(), so reject it too.
    if not isinstance(data, dict):
        return _reject("Request body must be valid JSON.")

    raw_prompt = data.get("prompt")
    prompt = "" if raw_prompt is None else str(raw_prompt).strip()
    if not prompt:
        return _reject("Field 'prompt' is required.")

    raw_max = data.get("max_new_tokens")
    try:
        max_tok = 256 if raw_max is None else int(raw_max)
    except (TypeError, ValueError, OverflowError):
        return _reject("Field 'max_new_tokens' must be an integer.")
    # Keep the 512 cap; the lower bound stops 0/negative values reaching generate().
    max_tok = max(1, min(max_tok, 512))

    raw_temp = data.get("temperature")
    try:
        temperature = 0.7 if raw_temp is None else float(raw_temp)
    except (TypeError, ValueError, OverflowError):
        return _reject("Field 'temperature' must be a number.")

    raw_fast = data.get("fast", True)
    if isinstance(raw_fast, str):
        # bool("false") would be True; interpret common textual falsy values.
        fast = raw_fast.strip().lower() not in ("", "false", "0", "no", "off")
    else:
        fast = bool(raw_fast)  # True = greedy (~2× faster)

    if not fast and not (math.isfinite(temperature) and temperature > 0):
        return _reject("Field 'temperature' must be a positive number when 'fast' is false.")

    return prompt, max_tok, temperature, fast, None
def make_gen_kwargs(inputs, max_tok, temperature, fast, streamer=None):
    """Assemble the keyword arguments for ``model_g.generate``.

    Args:
        inputs: tokenizer output; its entries are splatted into the kwargs.
        max_tok: value for ``max_new_tokens``.
        temperature: sampling temperature; ignored when ``fast`` is true.
        fast: greedy decoding when true, otherwise top-p sampling with a
            repetition penalty.
        streamer: optional streamer attached to the call.

    Returns:
        dict of generate() keyword arguments.

    Decoding stops on the tokenizer's EOS id. When the vocabulary contains
    ``<|im_end|>``, it also stops on that id: the prompt is built from
    ``<|im_start|>``/``<|im_end|>`` turns, and if EOS is a different token the
    model would otherwise keep generating past the end of its turn.
    """
    eos_id = tokenizer_g.eos_token_id
    stop_ids = [] if eos_id is None else [eos_id]
    im_end_id = tokenizer_g.convert_tokens_to_ids("<|im_end|>")
    # Unknown tokens map to unk_token_id (possibly None): never stop on those.
    if (
        isinstance(im_end_id, int)
        and im_end_id != tokenizer_g.unk_token_id
        and im_end_id not in stop_ids
    ):
        stop_ids.append(im_end_id)

    kw = dict(
        **inputs,
        max_new_tokens=max_tok,
        use_cache=True,
        pad_token_id=eos_id,
        # Plain int when there is a single stop id (as before); generate() also accepts a list.
        eos_token_id=stop_ids[0] if len(stop_ids) == 1 else (stop_ids or None),
    )
    if fast:
        kw["do_sample"] = False  # greedy decoding
    else:
        kw.update(do_sample=True, temperature=temperature, top_p=0.9, repetition_penalty=1.15)
    if streamer:
        kw["streamer"] = streamer
    return kw
# ─── Routes ───────────────────────────────────────────────────────────────────
@app.route("/favicon.svg")
def favicon_svg():
    """Return the inline SVG favicon."""
    return Response(response=FAVICON_SVG, mimetype="image/svg+xml")
@app.route("/favicon.ico")
@app.route("/favicon.png")
def favicon_fallback():
    """Serve a favicon file that sits next to this script, else the inline SVG.

    Both routes share this handler. favicon.png is checked before favicon.ico,
    so whichever is found first is served regardless of the requested name.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    found = next(
        (
            fname
            for fname in ("favicon.png", "favicon.ico")
            if os.path.exists(os.path.join(here, fname))
        ),
        None,
    )
    if found is None:
        return Response(FAVICON_SVG, mimetype="image/svg+xml")
    return send_from_directory(here, found)
@app.route("/", methods=["GET"])
def home():
    """Return the static landing page as text/html."""
    return Response(response=HTML_PAGE, mimetype="text/html")
@app.route("/health", methods=["GET"])
def health():
    """Liveness/readiness probe.

    500 with the load error, 503 while loading, otherwise 200 with the model
    id and quantization mode.
    """
    if model_error:
        body, status = {"status": "error", "detail": model_error}, 500
    elif not model_ready:
        body, status = {"status": "loading"}, 503
    else:
        return jsonify({"status": "ok", "model": MODEL_ID, "quant": quant_mode})
    return jsonify(body), status
# ── /generate ─────────────────────────────────────────────────────────────────
@app.route("/generate", methods=["POST"])
def generate():
    """Blocking generation endpoint: return the whole reply as one JSON object.

    Success body: ``response``, ``model``, ``quant``, ``elapsed_sec``,
    ``tokens`` (number of newly generated ids) and ``tps``.
    Errors: whatever guard()/parse_body() return; 504 when generation is still
    running after the timeout; 500 when generate() raised.
    """
    timeout_s = 120
    ok, err = guard()
    if not ok:
        return err
    prompt, max_tok, temperature, fast, err = parse_body()
    if err:
        return err
    inputs = tokenizer_g(build_prompt(prompt), return_tensors="pt")
    kwargs = make_gen_kwargs(inputs, max_tok, temperature, fast)
    # A Python thread cannot be killed, so also bound generate() itself.
    # Without this, a request already answered with 504 keeps decoding in the
    # background and starves later requests of CPU.
    kwargs["max_time"] = timeout_s
    n_prompt = inputs["input_ids"].shape[-1]
    result = {}

    def _infer():
        try:
            with torch.inference_mode():
                out = model_g.generate(**kwargs)
            new_ids = out[0][n_prompt:]  # drop the prompt ids echoed back by generate()
            text = tokenizer_g.decode(new_ids, skip_special_tokens=True)
            result["text"] = text.split("<|im_end|>")[0].strip() or "Je suis Cygnis-Alpha."
            result["n_tokens"] = len(new_ids)
        except Exception as e:
            result["error"] = str(e)

    t0 = time.time()
    # daemon=True so an abandoned worker never blocks interpreter shutdown.
    worker = threading.Thread(target=_infer, daemon=True)
    worker.start()
    worker.join(timeout=timeout_s)
    elapsed = round(time.time() - t0, 2)
    if worker.is_alive():
        return jsonify({"error": "Timeout >120s. Reduce max_new_tokens."}), 504
    if "error" in result:
        return jsonify({"error": result["error"]}), 500
    n = result.get("n_tokens", 0)
    tps = round(n / elapsed, 2) if elapsed > 0 else 0
    return jsonify({
        "response": result["text"],
        "model": MODEL_ID, "quant": quant_mode,
        "elapsed_sec": elapsed, "tokens": n, "tps": tps,
    })
# ── /generate/stream (SSE) ───────────────────────────────────────────────────
@app.route("/generate/stream", methods=["POST"])
def generate_stream():
    """Streaming generation endpoint using Server-Sent Events.

    Emits one event per text chunk produced by the streamer, then an optional
    ``[ERROR] ...`` event, a ``[STATS] ...`` event and finally ``[DONE]``.
    Payloads containing line breaks are framed as several ``data:`` lines,
    which SSE clients join back together with ``\\n``.
    """
    from transformers import StoppingCriteria, StoppingCriteriaList

    ok, err = guard()
    if not ok:
        return err
    prompt, max_tok, temperature, fast, err = parse_body()
    if err:
        return err
    inputs = tokenizer_g(build_prompt(prompt), return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer_g, skip_prompt=True, skip_special_tokens=True, timeout=15.0)
    kwargs = make_gen_kwargs(inputs, max_tok, temperature, fast, streamer=streamer)

    # Set once the consumer stops reading (normal end, streamer timeout or client
    # disconnect), so the background generate() halts at its next step instead
    # of decoding up to max_new_tokens for nobody.
    cancel = threading.Event()

    class _StopOnCancel(StoppingCriteria):
        def __call__(self, input_ids, scores, **kw):
            # One flag per batch row (recent transformers expect a BoolTensor).
            return torch.full(
                (input_ids.shape[0],), cancel.is_set(), dtype=torch.bool, device=input_ids.device
            )

    kwargs["stopping_criteria"] = StoppingCriteriaList([_StopOnCancel()])
    gen_err = {}

    def _run():
        try:
            with torch.inference_mode():
                model_g.generate(**kwargs)
        except Exception as e:
            gen_err["msg"] = str(e)
            # generate() only ends the streamer on success. End it here so the
            # consumer loop exits now instead of after the 15 s streamer timeout.
            try:
                streamer.end()
            except Exception:
                pass  # best effort: the consumer still falls back to the timeout

    def _sse(payload):
        # Frame one SSE event. A raw newline would end or corrupt the event, so
        # every line of the payload gets its own "data:" field.
        lines = payload.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        return "".join(f"data: {line}\n" for line in lines) + "\n"

    def event_stream():
        t0, n = time.time(), 0
        worker = threading.Thread(target=_run, daemon=True)
        worker.start()
        try:
            for token in streamer:
                clean = token.replace("<|im_end|>", "")
                if clean:
                    n += 1  # counts streamed text chunks; a chunk may span several tokens
                    yield _sse(clean)
        except Exception as e:
            # e.g. queue.Empty from the streamer timeout, whose str() is empty.
            yield _sse(f"[ERROR] {str(e) or type(e).__name__}")
        finally:
            cancel.set()
            worker.join(timeout=5)
        elapsed = round(time.time() - t0, 2)
        tps = round(n / elapsed, 2) if elapsed > 0 else 0
        if gen_err:
            yield _sse(f"[ERROR] {gen_err['msg']}")
        yield _sse(f"[STATS] tokens={n} elapsed={elapsed}s tps={tps} quant={quant_mode}")
        yield _sse("[DONE]")

    return Response(
        stream_with_context(event_stream()),
        mimetype="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "Connection": "keep-alive"},
    )
# ─── Entry ────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT (default 7860).
    app.run(
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        debug=False,
        threaded=True,
    )