| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import time |
| import uuid |
|
|
| import torch |
| import uvicorn |
| from fastapi import FastAPI |
| from pydantic import BaseModel |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Hugging Face model identifier served by this process, and the TCP port
# the HTTP server listens on when the file is run as a script.
MODEL_ID = "miromind-ai/MiroThinker-v1.5-30B"
PORT = 61005


# The tokenizer and weights are loaded once, at import time, and shared by
# every request handler defined below.
print(f"Loading {MODEL_ID} ... (this takes a while)")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()
# `hf_device_map` may be absent on the loaded model; fall back to "n/a".
print("Model loaded. Device map:", getattr(model, "hf_device_map", "n/a"))


# ASGI application object that the route decorators below register on.
app = FastAPI()
|
|
|
|
class ChatRequest(BaseModel):
    """Request body accepted by the ``/v1/chat/completions`` endpoint.

    Unknown fields are accepted rather than rejected (``extra = "allow"``),
    so clients may send additional parameters without a validation error.
    """

    # Model name; defaults to the model this server loaded.
    model: str = MODEL_ID
    # Conversation passed to the tokenizer's chat template (no element
    # type is enforced here).
    messages: list
    # Generation budget: two alternative spellings of the same limit.
    max_tokens: int | None = 16384
    max_completion_tokens: int | None = None
    # Sampling controls.
    temperature: float = 1.0
    top_p: float = 0.95
    # Declared so that clients may send it.
    stream: bool = False

    # Optional extras; ``None`` means "not provided by the client".
    repetition_penalty: float | None = None
    continue_final_message: bool | None = None
    add_generation_prompt: bool | None = None

    class Config:
        # Accept fields that are not declared above.
        extra = "allow"
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe: always answers with a fixed OK payload."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
@app.get("/v1/models")
def models():
    """List the single model served by this process.

    Returns:
        A ``{"object": "list", "data": [...]}`` mapping whose only entry
        carries ``MODEL_ID`` as its id.
    """
    served = {"id": MODEL_ID, "object": "model"}
    return {"object": "list", "data": [served]}
|
|
|
|
def _eos_token_ids() -> set[int]:
    """Collect the token ids the model config or tokenizer mark as end-of-sequence."""
    ids: set[int] = set()
    for source in (getattr(model, "generation_config", None), tokenizer):
        eos = getattr(source, "eos_token_id", None)
        if eos is None:
            continue
        if isinstance(eos, int):
            ids.add(eos)
        else:
            # Some configurations declare several EOS ids as a list.
            ids.update(int(token_id) for token_id in eos)
    return ids


@app.post("/v1/chat/completions")
def chat(req: ChatRequest):
    """Generate one assistant reply for the supplied conversation.

    The messages are rendered through the tokenizer's chat template, the
    model generates up to the requested number of new tokens, and the
    decoded text is returned as a single non-streamed ``chat.completion``
    object with token usage counts.

    Args:
        req: Validated request body. ``max_completion_tokens`` takes
            precedence over ``max_tokens``; if both are unset or zero the
            limit is 16384. ``req.stream`` and ``req.add_generation_prompt``
            are not consulted by this handler: the generation prompt is
            added exactly when ``continue_final_message`` is falsy.

    Returns:
        A dict with ``id``, ``object``, ``created``, ``model``, one entry in
        ``choices`` and a ``usage`` section.
    """
    continue_final = bool(req.continue_final_message)

    # Either open a fresh assistant turn or continue the last message;
    # the two template options are used as mutually exclusive here.
    enc = tokenizer.apply_chat_template(
        req.messages,
        add_generation_prompt=not continue_final,
        continue_final_message=continue_final,
        return_tensors="pt",
        return_dict=True,
    )
    enc = {k: v.to(model.device) for k, v in enc.items()}

    prompt_tokens = enc["input_ids"].shape[-1]
    max_new = req.max_completion_tokens or req.max_tokens or 16384

    sampling = req.temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new,
        "do_sample": sampling,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if sampling:
        # Sampling-only flags are passed only when sampling: under greedy
        # decoding they have no effect and can make transformers warn about
        # unused generation flags.
        gen_kwargs["temperature"] = max(req.temperature, 1e-5)
        gen_kwargs["top_p"] = req.top_p
    if req.repetition_penalty and req.repetition_penalty != 1.0:
        gen_kwargs["repetition_penalty"] = req.repetition_penalty

    with torch.inference_mode():
        out = model.generate(**enc, **gen_kwargs)

    # The output contains the prompt followed by the continuation.
    new_tokens = out[0][prompt_tokens:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    completion_tokens = new_tokens.shape[-1]

    # A reply that ends with an EOS token stopped naturally, even when that
    # token happened to be the last one the budget allowed.
    ended_on_eos = completion_tokens > 0 and int(new_tokens[-1]) in _eos_token_ids()
    if completion_tokens >= max_new and not ended_on_eos:
        finish = "length"
    else:
        finish = "stop"

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": finish,
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
|
|
|
|
if __name__ == "__main__":
    # Start the HTTP server only when executed as a script; the host
    # "0.0.0.0" listens on all network interfaces.
    uvicorn.run(app, port=PORT, host="0.0.0.0")
|
|