# north-air-api / app.py
# v5.1: revert INT8 — too lossy for 0.6B, keep LoRA merge + float32 (commit 26344b5)
import os
import re
import time
import json
from typing import List, Optional
from threading import Thread
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, TextIteratorStreamer
MODEL_DIR = os.getenv("MODEL_DIR", "./final_model")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.6"))
TOP_P = float(os.getenv("TOP_P", "0.85"))
SYSTEM_PROMPT = """You are North Air 1, built by North Air. 0.6B params, a custom model designed for helpful and concise responses.
Be direct, helpful, concise. Use markdown. Write clean code. Never fabricate facts.
If asked who you are: "I'm North Air 1, built by North Air." You are NOT ChatGPT/GPT-4/Claude/etc."""
class Message(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
messages: List[Message]
model: Optional[str] = "north-air-1"
max_new_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
system_prompt: Optional[str] = None
stream: Optional[bool] = False
enable_thinking: Optional[bool] = False
app = FastAPI(title="North Air 1 API", version="5.0.0")
# ─── Model Loading ───
MODEL = None
TOKENIZER = None
LOAD_ERROR = None
INFERENCE_MODE = "pytorch"
def _load_model():
"""Load model β€” merge LoRA if present, then apply INT8 quantization."""
global MODEL, TOKENIZER, LOAD_ERROR, INFERENCE_MODE
try:
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True, trust_remote_code=True)
if TOKENIZER.pad_token is None:
TOKENIZER.pad_token = TOKENIZER.eos_token
except Exception as e:
LOAD_ERROR = f"Tokenizer load failed: {e}"
return
try:
from transformers import AutoModelForCausalLM
adapter_cfg = os.path.join(MODEL_DIR, "adapter_config.json")
if os.path.exists(adapter_cfg):
# Merge LoRA into base model at startup β€” removes adapter overhead
# and enables INT8 quantization (which breaks on un-merged PEFT models)
from peft import AutoPeftModelForCausalLM
print("Loading PEFT model + merging LoRA weights...")
t0 = time.time()
peft_model = AutoPeftModelForCausalLM.from_pretrained(
MODEL_DIR, torch_dtype=torch.float32, device_map={"": "cpu"},
)
MODEL = peft_model.merge_and_unload()
print(f"LoRA merged in {time.time() - t0:.1f}s")
else:
MODEL = AutoModelForCausalLM.from_pretrained(
MODEL_DIR, torch_dtype=torch.float32, device_map={"": "cpu"},
trust_remote_code=True,
)
MODEL.eval()
INFERENCE_MODE = "pytorch-merged"
print(f"Model ready: {INFERENCE_MODE}")
except Exception as e:
LOAD_ERROR = str(e)
print(f"Model load failed: {e}")
_load_model()
@app.get("/health")
def health():
return {
"ok": MODEL is not None,
"model": "north-air-1",
"version": "5.0.0",
"architecture": "Qwen3-0.6B + LoRA r=64 (merged)",
"inference": INFERENCE_MODE,
"features": ["streaming", "thinking", "merged"],
"model_dir": MODEL_DIR,
"error": LOAD_ERROR,
}
def _build_prompt(messages: list, system: str, enable_thinking: bool) -> str:
has_system = any(m["role"] == "system" for m in messages)
if not has_system:
messages = [{"role": "system", "content": system}] + messages
if hasattr(TOKENIZER, "apply_chat_template"):
return TOKENIZER.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
enable_thinking=enable_thinking,
)
return "\n".join(f"{m['role']}: {m['content']}" for m in messages) + "\nassistant:"
def _parse_thinking(text: str) -> tuple:
think_match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
if think_match:
thinking = think_match.group(1).strip()
answer = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
return thinking, answer
return "", text
def _generation_kwargs(input_ids, attention_mask, max_new_tokens, temperature, top_p, **extra):
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"max_new_tokens": max_new_tokens,
"temperature": max(temperature, 0.01),
"top_p": top_p,
"top_k": 20,
"do_sample": True,
"repetition_penalty": 1.15,
"pad_token_id": TOKENIZER.pad_token_id,
"eos_token_id": TOKENIZER.eos_token_id,
**extra,
}
def _check_model():
if MODEL is None:
raise HTTPException(status_code=500, detail=f"Model failed to load: {LOAD_ERROR}")
if TOKENIZER is None:
raise HTTPException(status_code=500, detail=f"Tokenizer failed to load: {LOAD_ERROR}")
def _prepare_request(req: ChatRequest):
system = req.system_prompt or SYSTEM_PROMPT
messages = [{"role": m.role, "content": m.content} for m in req.messages]
enable_thinking = req.enable_thinking if req.enable_thinking is not None else False
prompt = _build_prompt(messages, system, enable_thinking)
batch = TOKENIZER(prompt, return_tensors="pt", add_special_tokens=False)
max_new_tokens = req.max_new_tokens or MAX_NEW_TOKENS
temperature = req.temperature if req.temperature is not None else TEMPERATURE
top_p = req.top_p if req.top_p is not None else TOP_P
return batch, max_new_tokens, temperature, top_p
@app.post("/chat")
def chat(req: ChatRequest):
_check_model()
if not req.messages:
raise HTTPException(status_code=400, detail="messages are required")
if req.stream:
return chat_stream(req)
batch, max_new_tokens, temperature, top_p = _prepare_request(req)
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
t0 = time.time()
with torch.no_grad():
out = MODEL.generate(
**_generation_kwargs(input_ids, attention_mask, max_new_tokens, temperature, top_p)
)
elapsed = time.time() - t0
generated_ids = out[0][input_ids.shape[1]:]
completion = TOKENIZER.decode(generated_ids, skip_special_tokens=True).strip()
thinking, answer = _parse_thinking(completion)
return {
"output": answer,
"thinking": thinking if thinking else None,
"model": "north-air-1",
"inference": INFERENCE_MODE,
"tokens_generated": len(generated_ids),
"latency_ms": round(elapsed * 1000),
}
@app.post("/chat/stream")
def chat_stream(req: ChatRequest):
_check_model()
if not req.messages:
raise HTTPException(status_code=400, detail="messages are required")
batch, max_new_tokens, temperature, top_p = _prepare_request(req)
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
streamer = TextIteratorStreamer(TOKENIZER, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = _generation_kwargs(
input_ids, attention_mask, max_new_tokens, temperature, top_p,
streamer=streamer,
)
t0 = time.time()
thread = Thread(target=_generate_in_thread, args=(gen_kwargs,))
thread.start()
def event_stream():
token_count = 0
in_thinking = False
buf = ""
for token_text in streamer:
buf += token_text
token_count += 1
if "<think>" in buf and not in_thinking:
in_thinking = True
yield f"data: {json.dumps({'type': 'thinking_start'})}\n\n"
after = buf.split("<think>", 1)[1]
buf = after if after else ""
if "</think>" in buf and in_thinking:
before = buf.split("</think>", 1)[0]
if before:
yield f"data: {json.dumps({'type': 'thinking', 'text': before})}\n\n"
in_thinking = False
yield f"data: {json.dumps({'type': 'thinking_end'})}\n\n"
after = buf.split("</think>", 1)[1].lstrip()
buf = ""
if after:
yield f"data: {json.dumps({'type': 'text', 'text': after})}\n\n"
continue
partial_open = "<think"
partial_close = "</think"
if not in_thinking and buf.endswith(tuple(partial_open[:i] for i in range(1, len(partial_open) + 1))):
continue
if in_thinking and buf.endswith(tuple(partial_close[:i] for i in range(1, len(partial_close) + 1))):
continue
if buf:
evt_type = "thinking" if in_thinking else "text"
yield f"data: {json.dumps({'type': evt_type, 'text': buf})}\n\n"
buf = ""
if buf:
evt_type = "thinking" if in_thinking else "text"
yield f"data: {json.dumps({'type': evt_type, 'text': buf})}\n\n"
if in_thinking:
yield f"data: {json.dumps({'type': 'thinking_end'})}\n\n"
thread.join()
elapsed = time.time() - t0
yield f"data: {json.dumps({'type': 'done', 'tokens_generated': token_count, 'latency_ms': round(elapsed * 1000), 'inference': INFERENCE_MODE})}\n\n"
return StreamingResponse(event_stream(), media_type="text/event-stream")
def _generate_in_thread(kwargs):
with torch.no_grad():
MODEL.generate(**kwargs)