# FastAPI inference server with quantized model support
import contextlib
import json
import logging
import os
import threading
import time
import uuid

import torch
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

load_dotenv()

logger = logging.getLogger(__name__)

# Environment-driven configuration
MODEL_PATH = os.getenv("MODEL_PATH", "./models/mistral-finetuned-mk")
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8000"))
ALLOW_ORIGINS = [o.strip() for o in os.getenv("ALLOW_ORIGINS", "*").split(",") if o.strip()]

# Quantization / precision toggles (LOAD_IN_4BIT wins if both are set)
LOAD_IN_4BIT = os.getenv("LOAD_IN_4BIT", "false").lower() == "true"
LOAD_IN_8BIT = os.getenv("LOAD_IN_8BIT", "false").lower() == "true"
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "true").lower() == "true"
TORCH_DTYPE = os.getenv("TORCH_DTYPE", "float16").lower()  # float16|bfloat16|float32

_DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}
# Unrecognized TORCH_DTYPE values fall back to float16.
torch_dtype = _DTYPE_MAP.get(TORCH_DTYPE, torch.float16)

app = FastAPI()

# Enable CORS for simple web UIs and external callers.
# NOTE(review): with ALLOW_ORIGINS="*" (the default) plus allow_credentials=True,
# Starlette echoes the caller's Origin on credentialed requests, so any site may
# call the API with cookies. Harmless while there is no cookie-based auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Populated lazily by ensure_model_loaded().
model = None
tokenizer = None
# Sync `def` endpoints run in FastAPI's thread pool, so the first few requests
# can arrive concurrently; this lock ensures the model is loaded only once.
_model_lock = threading.Lock()


def ensure_model_loaded():
    """Load the model and tokenizer from MODEL_PATH into the module globals.

    Does nothing once both are loaded. Quantization follows LOAD_IN_4BIT
    (takes precedence) or LOAD_IN_8BIT; otherwise weights are loaded in
    ``torch_dtype``. Layers are placed with ``device_map="auto"``.

    Raises:
        RuntimeError: MODEL_PATH is not an existing local path and contains
            no "/" (so it cannot be a Hugging Face repo id either).
    """
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return
    with _model_lock:
        # Re-check: another thread may have finished loading while we waited.
        if model is not None and tokenizer is not None:
            return
        if not os.path.exists(MODEL_PATH) and "/" not in MODEL_PATH:
            raise RuntimeError(
                f"Model path '{MODEL_PATH}' not found. Set MODEL_PATH to a valid directory or HF repo id."
            )

        print("⏳ Loading model...")
        load_kwargs = {
            "device_map": "auto",
            "trust_remote_code": TRUST_REMOTE_CODE,
        }
        if LOAD_IN_4BIT or LOAD_IN_8BIT:
            # Bare load_in_4bit / load_in_8bit kwargs to from_pretrained are
            # deprecated; BitsAndBytesConfig is the supported form. Imported
            # lazily because it is only needed when quantizing.
            from transformers import BitsAndBytesConfig

            if LOAD_IN_4BIT:
                load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
            else:
                load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        else:
            load_kwargs["torch_dtype"] = torch_dtype

        loaded_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, **load_kwargs)
        loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=TRUST_REMOTE_CODE)
        # Publish only after both loads succeeded.
        tokenizer = loaded_tokenizer
        model = loaded_model
        print("✅ Model loaded successfully!")


# === Generation helpers ===

def _model_inputs(text):
    """Tokenize *text* into PyTorch tensors placed on the model's device."""
    return tokenizer(text, return_tensors="pt").to(model.device)


def _sampling_kwargs(temperature, top_p):
    """Return the sampling-related ``generate()`` kwargs for a request.

    A non-positive temperature selects greedy decoding: transformers rejects
    ``temperature <= 0`` when ``do_sample=True``, and OpenAI clients commonly
    send 0 to mean "deterministic".
    """
    if temperature <= 0:
        return {"do_sample": False}
    return {"do_sample": True, "temperature": temperature, "top_p": top_p}


def _normalize_stop(stop):
    """Turn an OpenAI ``stop`` field (str, list of str, or None) into a list.

    Empty strings are dropped, since they would match at position 0.
    """
    if isinstance(stop, str):
        stop = [stop]
    return [s for s in stop or () if s]


def _find_stop(text, stops):
    """Return the index of the earliest of *stops* in *text*, or -1 if none occurs."""
    return min((i for i in (text.find(s) for s in stops) if i >= 0), default=-1)


class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that ends ``generate()`` once *event* is set.

    Lets the HTTP side cancel a background generation instead of letting it
    run on to ``max_new_tokens``.
    """

    def __init__(self, event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        # One flag per batch row, which is what recent transformers versions expect.
        return torch.full(
            (input_ids.shape[0],), self.event.is_set(), dtype=torch.bool, device=input_ids.device
        )


def _stream_text(inputs, gen_kwargs, stops=()):
    """Run ``generate()`` in a background thread and yield new text pieces.

    The prompt is not echoed. When one of *stops* appears, the text before it
    is yielded and the stream ends; the stop sequence itself is never yielded.
    Generation is cancelled whenever this generator finishes or is closed.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    cancel = threading.Event()

    def run():
        try:
            model.generate(
                **inputs,
                **gen_kwargs,
                streamer=streamer,
                stopping_criteria=StoppingCriteriaList([_StopOnEvent(cancel)]),
            )
        except Exception:
            logger.exception("Streaming generation failed")
            # generate() never reached streamer.end(); call it here so the
            # consumer loop below terminates instead of blocking forever.
            streamer.end()

    threading.Thread(target=run, daemon=True).start()
    # Hold back enough trailing characters to catch a stop sequence that is
    # split across two streamed pieces.
    holdback = max((len(s) for s in stops), default=1) - 1
    pending = ""
    try:
        for piece in streamer:
            pending += piece
            cut = _find_stop(pending, stops)
            if cut >= 0:
                pending = pending[:cut]
                break
            safe = len(pending) - holdback
            if safe > 0:
                yield pending[:safe]
                pending = pending[safe:]
        if pending:
            yield pending
    finally:
        cancel.set()


def _generate_completion(inputs, gen_kwargs, stops=()):
    """Generate synchronously and return ``(text, completion_tokens, finish_reason)``.

    Only newly generated tokens are decoded; the prompt is sliced off.
    ``finish_reason`` is "stop" when a stop sequence truncated the text.
    Otherwise it is "length" if the token budget was exhausted, and "stop"
    in all remaining cases.
    """
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    new_ids = outputs[0][prompt_len:]
    completion_tokens = int(new_ids.shape[-1])
    text = tokenizer.decode(new_ids, skip_special_tokens=True)
    finish_reason = "length" if completion_tokens >= gen_kwargs["max_new_tokens"] else "stop"
    cut = _find_stop(text, stops)
    if cut >= 0:
        text, finish_reason = text[:cut], "stop"
    return text, completion_tokens, finish_reason


class GenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 128
    temperature: float = 0.7
    top_p: float = 0.9
    repetition_penalty: float = 1.1
    stream: bool = False


@app.post("/generate")
def generate(req: GenerateRequest):
    """Legacy generation endpoint.

    Returns the decoded output sequence, which includes the prompt text.
    With ``stream=True`` the same text is sent as a single text/plain chunk
    once generation finishes; it is not token-by-token streaming.
    """
    ensure_model_loaded()
    inputs = _model_inputs(req.prompt)
    gen_kwargs = {
        "max_new_tokens": req.max_new_tokens,
        "repetition_penalty": req.repetition_penalty,
        **_sampling_kwargs(req.temperature, req.top_p),
    }

    def stream_tokens():
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
        yield tokenizer.decode(outputs[0], skip_special_tokens=True)

    if req.stream:
        return StreamingResponse(stream_tokens(), media_type="text/plain")

    # Non-streaming JSON response
    return {"response": next(stream_tokens())}


# === OpenAI-compatible schemas ===

class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: str | None = None
    messages: list[ChatMessage]
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    stream: bool = False
    # OpenAI accepts a single string or a list of strings.
    stop: str | list[str] | None = None


class CompletionRequest(BaseModel):
    model: str | None = None
    prompt: str
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    stream: bool = False
    # OpenAI accepts a single string or a list of strings.
    stop: str | list[str] | None = None


def build_prompt_from_messages(messages: list[ChatMessage]) -> str:
    """Render chat *messages* into a single prompt string.

    Prefers the tokenizer's chat template (with the generation prompt
    appended). If the template fails or renders empty, falls back to a plain
    Macedonian "Корисник:/Асистент:/Систем:" transcript that ends with an
    open "Асистент:" turn.
    """
    payload = [{"role": m.role, "content": m.content} for m in messages]
    try:
        formatted = tokenizer.apply_chat_template(
            payload,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Best effort by design (e.g. no template, or the template rejects the
        # role sequence), but leave a trace instead of failing silently.
        logger.warning("Chat template failed; using plain-text prompt fallback", exc_info=True)
    else:
        if isinstance(formatted, str) and formatted.strip():
            return formatted

    # Fallback simple format; any role other than user/assistant is labelled "Систем:".
    prefixes = {"user": "Корисник:", "assistant": "Асистент:"}
    lines = [f"{prefixes.get(m.role, 'Систем:')} {m.content}" for m in messages]
    lines.append("Асистент:")
    return "\n".join(lines)


def sse_pack(data: dict) -> str:
    """Serialize *data* as one Server-Sent Events ``data:`` frame (UTF-8 kept as-is)."""
    return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"


@app.post("/v1/completions")
def completions(req: CompletionRequest):
    """OpenAI-compatible text completion endpoint.

    Non-streaming responses contain only the generated continuation (no
    prompt echo), with token usage. Streaming responses are SSE
    ``text_completion.chunk`` frames ending with a ``finish_reason`` frame and
    ``data: [DONE]``. ``temperature <= 0`` selects greedy decoding; ``stop``
    truncates the output before the earliest stop sequence.
    """
    ensure_model_loaded()
    inputs = _model_inputs(req.prompt)
    prompt_tokens = int(inputs["input_ids"].shape[-1])
    gen_kwargs = {
        "max_new_tokens": req.max_tokens,
        **_sampling_kwargs(req.temperature, req.top_p),
    }
    stops = _normalize_stop(req.stop)

    request_id = f"cmpl-{uuid.uuid4().hex[:24]}"
    model_name = os.getenv("MODEL_ID", "mk-llm")
    created = int(time.time())

    def chunk(text, finish_reason=None):
        return sse_pack({
            "id": request_id,
            "object": "text_completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [{"text": text, "index": 0, "finish_reason": finish_reason}],
        })

    if req.stream:
        def event_stream():
            # Initial empty chunk (text completions don't send a role)
            yield chunk("")
            # closing() guarantees the background generation is cancelled
            # even if this stream is abandoned mid-way.
            with contextlib.closing(_stream_text(inputs, gen_kwargs, stops)) as pieces:
                for piece in pieces:
                    yield chunk(piece)
            yield chunk("", "stop")
            yield "data: [DONE]\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    text, completion_tokens, finish_reason = _generate_completion(inputs, gen_kwargs, stops)
    return {
        "id": request_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "choices": [{"text": text, "index": 0, "finish_reason": finish_reason}],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
@app.post("/v1/chat/completions")
def chat_completions(req: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint.

    Renders ``req.messages`` with ``build_prompt_from_messages`` and generates
    an assistant reply. Non-streaming responses contain only the newly
    generated text (the prompt is sliced off) plus token usage. Streaming
    responses are SSE ``chat.completion.chunk`` frames: a role delta, content
    deltas, a ``finish_reason`` frame, then ``data: [DONE]``.

    ``temperature <= 0`` selects greedy decoding. ``stop`` (a string or a list
    of strings) truncates the reply before the earliest stop sequence.
    """
    import logging  # function-local so this endpoint needs no module-level changes

    ensure_model_loaded()
    prompt = build_prompt_from_messages(req.messages)
    # A chat template rendered with tokenize=False usually already starts with
    # the BOS token; letting the tokenizer prepend another would duplicate it.
    bos = tokenizer.bos_token
    add_special_tokens = not (bos and prompt.startswith(bos))
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=add_special_tokens).to(model.device)
    prompt_tokens = int(inputs["input_ids"].shape[-1])

    gen_kwargs = {"max_new_tokens": req.max_tokens}
    if req.temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=req.temperature, top_p=req.top_p)
    else:
        # transformers rejects temperature <= 0 when sampling; OpenAI clients
        # commonly send 0 meaning "deterministic", so decode greedily instead.
        gen_kwargs["do_sample"] = False

    # `stop` may be a single string or a list; empty strings would match at 0.
    raw_stop = [req.stop] if isinstance(req.stop, str) else (req.stop or [])
    stops = [s for s in raw_stop if s]

    def first_stop(text: str) -> int:
        """Index of the earliest stop sequence in *text*, or -1 if none occurs."""
        return min((i for i in (text.find(s) for s in stops) if i >= 0), default=-1)

    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
    model_name = os.getenv("MODEL_ID", "mk-llm")
    created = int(time.time())

    def chunk(delta: dict, finish_reason=None) -> str:
        return sse_pack({
            "id": request_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [{"delta": delta, "index": 0, "finish_reason": finish_reason}],
        })

    if req.stream:
        def event_stream():
            from transformers import StoppingCriteria, StoppingCriteriaList

            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
            cancel = threading.Event()

            class StopWhenCancelled(StoppingCriteria):
                """Ends generate() once `cancel` is set (stop hit or stream closed)."""

                def __call__(self, input_ids, scores, **kwargs):
                    return torch.full(
                        (input_ids.shape[0],), cancel.is_set(), dtype=torch.bool, device=input_ids.device
                    )

            def run():
                try:
                    model.generate(
                        **inputs,
                        **gen_kwargs,
                        streamer=streamer,
                        stopping_criteria=StoppingCriteriaList([StopWhenCancelled()]),
                    )
                except Exception:
                    logging.getLogger(__name__).exception("Chat generation failed")
                    # generate() never reached streamer.end(); call it so the
                    # loop below terminates instead of blocking forever.
                    streamer.end()

            threading.Thread(target=run, daemon=True).start()
            try:
                # Initial role delta
                yield chunk({"role": "assistant"})
                # Hold back enough trailing characters to catch a stop
                # sequence split across two streamed pieces.
                holdback = max((len(s) for s in stops), default=1) - 1
                pending = ""
                for piece in streamer:
                    pending += piece
                    cut = first_stop(pending)
                    if cut >= 0:
                        pending = pending[:cut]
                        break
                    safe = len(pending) - holdback
                    if safe > 0:
                        yield chunk({"content": pending[:safe]})
                        pending = pending[safe:]
                if pending:
                    yield chunk({"content": pending})
                yield chunk({}, "stop")
                yield "data: [DONE]\n\n"
            finally:
                # Runs on normal completion, stop-sequence break, or when the
                # generator is closed early; halts the background generation.
                cancel.set()

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    # Decode only the newly generated tokens; outputs[0] also holds the prompt.
    new_ids = outputs[0][prompt_tokens:]
    completion_tokens = int(new_ids.shape[-1])
    text = tokenizer.decode(new_ids, skip_special_tokens=True)
    finish_reason = "length" if completion_tokens >= req.max_tokens else "stop"
    cut = first_stop(text)
    if cut >= 0:
        text, finish_reason = text[:cut], "stop"

    return {
        "id": request_id,
        "object": "chat.completion",
        "created": created,
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": finish_reason,
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }


@app.get("/v1/models")
def list_models():
    """List the single served model in OpenAI ``/v1/models`` format.

    The id comes from MODEL_ID (default "mk-llm"), matching the ``model``
    field reported by the completion endpoints.
    """
    return {
        "object": "list",
        "data": [
            {
                "id": os.getenv("MODEL_ID", "mk-llm"),
                "object": "model",
                "created": int(time.time()),
                "owned_by": "community",
            }
        ],
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=HOST, port=PORT)