File size: 3,914 Bytes
569d24c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f2c5db
 
 
569d24c
 
 
 
1f2c5db
 
 
 
569d24c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f2c5db
569d24c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Minimal OpenAI-compatible server for MiroThinker using plain transformers.
# No vLLM/SGLang needed. Works wherever PyTorch CUDA works (incl. Windows).
#
#   pip install transformers accelerate fastapi uvicorn torch
#   python mirothinker_server.py
#
# Then set in miroflow-agent .env / config:
#   base_url: http://localhost:61005/v1   (provider "qwen", any api_key)
#
# Handles the two non-standard params miroflow-agent sends via extra_body:
#   - repetition_penalty
#   - continue_final_message / add_generation_prompt (resume truncated output)

import time
import uuid

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "miromind-ai/MiroThinker-v1.5-30B"  # change if using 1.7 etc.
PORT = 61005

# The tokenizer and weights are loaded once at import time; every request
# handled by this process reuses the same objects.
print(f"Loading {MODEL_ID} ... (this takes a while)")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    # "auto" lets accelerate decide layer placement across the visible devices.
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()
# hf_device_map is only present when a device map was applied; report "n/a" otherwise.
print("Model loaded. Device map:", getattr(model, "hf_device_map", "n/a"))

app = FastAPI()


class ChatRequest(BaseModel):
    """Request body accepted by the POST /v1/chat/completions endpoint.

    Covers the subset of OpenAI chat-completion parameters this server reads,
    plus the non-standard fields that miroflow-agent sends via ``extra_body``.
    Unknown fields are accepted rather than rejected (see ``Config``).
    """

    # Echoed back in the response; defaults to the model loaded by this server.
    model: str = MODEL_ID
    # Chat messages handed to the tokenizer's chat template.
    messages: list
    # Generation budget; max_completion_tokens is the newer OpenAI spelling.
    max_tokens: int | None = 16384
    max_completion_tokens: int | None = None
    temperature: float = 1.0
    top_p: float = 0.95
    # NOTE(review): declared for client compatibility; the handler in this
    # file always returns one complete, non-streamed response.
    stream: bool = False
    # extra_body params arrive as top-level fields with the openai SDK
    repetition_penalty: float | None = None
    continue_final_message: bool | None = None
    add_generation_prompt: bool | None = None

    class Config:
        # Tolerate any additional OpenAI parameters a client may send.
        extra = "allow"


@app.get("/health")
def health():
    """Liveness probe: report that the server process is up."""
    payload = {"status": "ok"}
    return payload


@app.get("/v1/models")
def models():
    """Return an OpenAI-style model list containing the one served model."""
    served = {"id": MODEL_ID, "object": "model"}
    return {"object": "list", "data": [served]}


def _stop_token_ids():
    """Return the set of token ids that terminate generation.

    Prefers the model's generation config (which may hold one id or a list of
    ids) and falls back to the tokenizer's EOS id when none is configured.
    """
    configured = getattr(
        getattr(model, "generation_config", None), "eos_token_id", None
    )
    if configured is None:
        configured = tokenizer.eos_token_id
    if configured is None:
        return set()
    if isinstance(configured, int):
        return {configured}
    return set(configured)


def _finish_reason(new_token_ids, max_new, stop_ids):
    """Classify why generation ended as OpenAI's "stop" or "length".

    A sequence whose last token is a stop token ended naturally, even when
    that token happened to land exactly on the ``max_new`` budget.
    """
    if new_token_ids and new_token_ids[-1] in stop_ids:
        return "stop"
    return "length" if len(new_token_ids) >= max_new else "stop"


@app.post("/v1/chat/completions")
def chat(req: ChatRequest):
    """Run one non-streaming chat completion and return an OpenAI-style body.

    The prompt is rendered with the tokenizer's chat template. When
    ``continue_final_message`` is set, the final message is left open so the
    model resumes it and no generation prompt is added. Otherwise the
    request's ``add_generation_prompt`` is honoured, defaulting to True.

    The response contains a single choice with the decoded new tokens, a
    ``finish_reason`` of "stop" or "length", and token usage counts.

    NOTE: ``req.stream`` is accepted but streaming is not implemented; the
    full completion is always returned in one response.
    """
    continue_final = bool(req.continue_final_message)
    if continue_final:
        # Continuing a message and opening a new assistant turn are mutually
        # exclusive in the chat-template API, so continuation wins.
        add_generation_prompt = False
    elif req.add_generation_prompt is None:
        add_generation_prompt = True
    else:
        add_generation_prompt = req.add_generation_prompt

    # Build the prompt with the model's own chat template.
    # return_dict=True gives a BatchEncoding (input_ids + attention_mask);
    # handle it as a dict so .shape / generate work correctly.
    enc = tokenizer.apply_chat_template(
        req.messages,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final,
        return_tensors="pt",
        return_dict=True,
    )
    enc = {k: v.to(model.device) for k, v in enc.items()}

    prompt_tokens = enc["input_ids"].shape[-1]
    max_new = req.max_completion_tokens or req.max_tokens or 16384

    sampling = req.temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new,
        "do_sample": sampling,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if sampling:
        # Sampling-only knobs are passed only when sampling is on; greedy
        # decoding ignores them.
        gen_kwargs["temperature"] = max(req.temperature, 1e-5)
        gen_kwargs["top_p"] = req.top_p
    if req.repetition_penalty and req.repetition_penalty != 1.0:
        gen_kwargs["repetition_penalty"] = req.repetition_penalty

    with torch.inference_mode():
        out = model.generate(**enc, **gen_kwargs)

    # generate() returns prompt + completion; keep only the newly added part.
    new_tokens = out[0][prompt_tokens:]
    new_token_ids = new_tokens.tolist()
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    completion_tokens = len(new_token_ids)
    finish = _finish_reason(new_token_ids, max_new, _stop_token_ids())

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": finish,
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }


if __name__ == "__main__":
    # Serve on all network interfaces at the configured port.
    # NOTE(review): there is no authentication; restrict the host if the
    # machine is reachable from untrusted networks.
    uvicorn.run(
        app,
        port=PORT,
        host="0.0.0.0",
    )