# compvis / mirothinker server.py
# Dexter's picture
# Upload mirothinker server.py
# 1f2c5db verified
# Raw
# History Blame Contribute Delete
# 3.91 kB
# Minimal OpenAI-compatible server for MiroThinker using plain transformers.
# No vLLM/SGLang needed. Works wherever PyTorch CUDA works (incl. Windows).
#
# pip install transformers accelerate fastapi uvicorn torch
# python mirothinker_server.py
#
# Then set in miroflow-agent .env / config:
# base_url: http://localhost:61005/v1 (provider "qwen", any api_key)
#
# Handles the two non-standard params miroflow-agent sends via extra_body:
# - repetition_penalty
# - continue_final_message / add_generation_prompt (resume truncated output)
import time
import uuid
import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_ID = "miromind-ai/MiroThinker-v1.5-30B"  # change if using 1.7 etc.
PORT = 61005

# Loading happens at import time, so the endpoints below can rely on the
# module-level `tokenizer` and `model` being ready.
print(f"Loading {MODEL_ID} ... (this takes a while)")

# NOTE: trust_remote_code=True is passed for both the tokenizer and the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",  # spread layers across devices (original setup: both A100s)
    torch_dtype=torch.bfloat16,
)
model.eval()

# Not every loaded model exposes hf_device_map; fall back to "n/a" when absent.
print("Model loaded. Device map:", getattr(model, "hf_device_map", "n/a"))

app = FastAPI()
class ChatRequest(BaseModel):
    """Request body for the chat-completions endpoint.

    Mirrors the subset of the OpenAI chat-completions request that this
    server reads, plus the non-standard fields miroflow-agent sends through
    ``extra_body``.
    """

    # Model name supplied by the client; defaults to the loaded MODEL_ID.
    model: str = MODEL_ID
    # Conversation messages, handed to the tokenizer's chat template.
    messages: list
    # Generation budget. Both spellings are accepted; each may be null.
    max_tokens: int | None = 16384
    max_completion_tokens: int | None = None
    # Sampling controls.
    temperature: float = 1.0
    top_p: float = 0.95
    # Accepted for request compatibility.
    stream: bool = False
    # extra_body params arrive as top-level fields with the openai SDK
    repetition_penalty: float | None = None
    continue_final_message: bool | None = None
    add_generation_prompt: bool | None = None

    class Config:
        # Tolerate request fields that are not declared above.
        # NOTE(review): this is the pydantic v1-style config class; presumably
        # still honoured on v2 but deprecated there -- confirm installed version.
        extra = "allow"
@app.get("/health")
def health():
    """Liveness probe: always answers with an "ok" status payload."""
    payload = {"status": "ok"}
    return payload
@app.get("/v1/models")
def models():
    """List the single model served by this process in OpenAI list format."""
    served = {"id": MODEL_ID, "object": "model"}
    return {"object": "list", "data": [served]}
def _eos_token_ids() -> set[int]:
    """Collect the end-of-sequence ids known to the tokenizer and the model's generation config."""
    ids: set[int] = set()
    gen_cfg = getattr(model, "generation_config", None)
    for candidate in (tokenizer.eos_token_id, getattr(gen_cfg, "eos_token_id", None)):
        if candidate is None:
            continue
        if isinstance(candidate, int):
            ids.add(candidate)
        else:
            # The generation config may hold a list of ids.
            ids.update(candidate)
    return ids


def _sampling_kwargs(temperature: float, top_p: float) -> dict:
    """Return the sampling-related ``generate`` kwargs for the requested temperature."""
    if temperature > 0:
        return {
            "do_sample": True,
            "temperature": max(temperature, 1e-5),
            "top_p": top_p,
        }
    # Greedy decoding: leave temperature/top_p out so sampling-only flags are
    # not passed alongside do_sample=False.
    return {"do_sample": False}


@app.post("/v1/chat/completions")
def chat(req: ChatRequest):
    """Run one non-streaming chat completion and return an OpenAI-style response.

    Args:
        req: Parsed request body. ``max_completion_tokens`` takes precedence
            over ``max_tokens``; a missing/zero budget falls back to 16384.
            ``continue_final_message`` resumes the last message instead of
            opening a new assistant turn. ``add_generation_prompt`` is
            honoured when given and the final message is not being continued.

    Returns:
        A dict shaped like an OpenAI ``chat.completion`` object with one
        choice and token usage counts.

    NOTE: ``req.stream`` is accepted but not acted on; the full completion is
    always returned in a single JSON response.
    """
    continue_final = bool(req.continue_final_message)
    if continue_final:
        # Continuing the final message and opening a fresh assistant turn are
        # mutually exclusive, so the generation prompt is always off here.
        add_generation_prompt = False
    elif req.add_generation_prompt is None:
        add_generation_prompt = True
    else:
        add_generation_prompt = req.add_generation_prompt

    # Build the prompt with the model's own chat template.
    # return_dict=True gives a BatchEncoding (input_ids + attention_mask);
    # handle it as a dict so .shape / generate work correctly.
    enc = tokenizer.apply_chat_template(
        req.messages,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final,
        return_tensors="pt",
        return_dict=True,
    )
    enc = {k: v.to(model.device) for k, v in enc.items()}
    prompt_tokens = enc["input_ids"].shape[-1]

    max_new = req.max_completion_tokens or req.max_tokens or 16384
    gen_kwargs = dict(
        max_new_tokens=max_new,
        pad_token_id=tokenizer.eos_token_id,
        **_sampling_kwargs(req.temperature, req.top_p),
    )
    if req.repetition_penalty and req.repetition_penalty != 1.0:
        gen_kwargs["repetition_penalty"] = req.repetition_penalty

    with torch.inference_mode():
        out = model.generate(**enc, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the continuation.
    new_tokens = out[0][prompt_tokens:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    completion_tokens = new_tokens.shape[-1]

    # Hitting the token budget only counts as truncation when the model did
    # not finish on an end-of-sequence token with its very last step.
    ended_on_eos = completion_tokens > 0 and int(new_tokens[-1]) in _eos_token_ids()
    if completion_tokens >= max_new and not ended_on_eos:
        finish = "length"
    else:
        finish = "stop"

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": finish,
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
if __name__ == "__main__":
    # Script entry point: serve the app on every interface at the configured port.
    uvicorn.run(app, port=PORT, host="0.0.0.0")