|
|
|
|
|
|
|
|
|
|
|
from fastapi import FastAPI
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
from pydantic import BaseModel
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
from transformers import TextIteratorStreamer
|
|
|
import torch
|
|
|
import threading
|
|
|
import time
|
|
|
import uuid
|
|
|
import os
|
|
|
from dotenv import load_dotenv
|
|
|
import json
|
|
|
import json
|
|
|
|
|
|
# Load variables from a .env file (when one is present) before the
# os.getenv lookups below read the environment.
load_dotenv()


def _env_flag(name: str, default: str) -> bool:
    """Return True when env var *name* (or *default*) equals "true", case-insensitively."""
    return os.getenv(name, default).lower() == "true"


# Where the model lives and where the HTTP server binds.
MODEL_PATH = os.getenv("MODEL_PATH", "./models/mistral-finetuned-mk")
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8000"))

# Comma-separated CORS origins; surrounding whitespace is trimmed and empty
# entries are dropped.
ALLOW_ORIGINS = [
    origin
    for origin in (part.strip() for part in os.getenv("ALLOW_ORIGINS", "*").split(","))
    if origin
]

# Model-loading switches, each enabled only by the literal value "true".
LOAD_IN_4BIT = _env_flag("LOAD_IN_4BIT", "false")
LOAD_IN_8BIT = _env_flag("LOAD_IN_8BIT", "false")
# NOTE(review): defaults to "true", which lets model repositories run their
# own Python code at load time -- confirm this default is intended.
TRUST_REMOTE_CODE = _env_flag("TRUST_REMOTE_CODE", "true")
TORCH_DTYPE = os.getenv("TORCH_DTYPE", "float16").lower()

# Supported TORCH_DTYPE spellings; anything else falls back to float16.
_DTYPE_MAP = dict(
    float16=torch.float16,
    bfloat16=torch.bfloat16,
    float32=torch.float32,
)
torch_dtype = _DTYPE_MAP[TORCH_DTYPE] if TORCH_DTYPE in _DTYPE_MAP else torch.float16
|
|
|
|
|
|
# The ASGI application object that every route below is registered on.
app = FastAPI()

# Credentialed cross-origin requests are only enabled when the operator has
# listed explicit origins. Combining a wildcard origin with credentials would
# let any website make credentialed requests to this server.
_cors_allow_credentials = "*" not in ALLOW_ORIGINS

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOW_ORIGINS,
    allow_credentials=_cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# Lazily initialised singletons; both are set by ensure_model_loaded().
model = None
tokenizer = None

# Serialises first-time loading. The endpoints are plain `def` handlers, which
# FastAPI executes on worker threads, so two early requests could otherwise
# both start loading the model.
_model_load_lock = threading.Lock()


def ensure_model_loaded():
    """Load the model and tokenizer on first use; later calls return at once.

    The weights are loaded with ``device_map="auto"``. 4-bit loading takes
    precedence over 8-bit; when neither is requested the configured
    ``torch_dtype`` is used. The module-level ``model`` and ``tokenizer`` are
    assigned only after *both* loads succeed, so a failed tokenizer load does
    not leave a half-initialised state behind.

    Raises:
        RuntimeError: if ``MODEL_PATH`` does not exist on disk and contains
            no "/" (i.e. it cannot be a ``namespace/name`` style repo id).
    """
    global model, tokenizer

    # Fast path: nothing to do once both objects exist.
    if model is not None and tokenizer is not None:
        return

    with _model_load_lock:
        # Another thread may have finished loading while we waited.
        if model is not None and tokenizer is not None:
            return

        print("⏳ Loading model...")

        model_load_kwargs = {
            "device_map": "auto",
            "trust_remote_code": TRUST_REMOTE_CODE,
        }
        if LOAD_IN_4BIT:
            model_load_kwargs["load_in_4bit"] = True
        elif LOAD_IN_8BIT:
            model_load_kwargs["load_in_8bit"] = True
        else:
            model_load_kwargs["torch_dtype"] = torch_dtype

        if not os.path.exists(MODEL_PATH) and not MODEL_PATH.count("/"):
            raise RuntimeError(
                f"Model path '{MODEL_PATH}' not found. Set MODEL_PATH to a valid directory or HF repo id."
            )

        loaded_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, **model_load_kwargs)
        loaded_tokenizer = AutoTokenizer.from_pretrained(
            MODEL_PATH, trust_remote_code=TRUST_REMOTE_CODE
        )

        # Publish only once both objects are ready.
        tokenizer = loaded_tokenizer
        model = loaded_model

        print("✅ Model loaded successfully!")
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body accepted by ``POST /generate``."""

    # Text the model should continue.
    prompt: str

    # Sampling controls, forwarded to ``model.generate`` by the endpoint.
    max_new_tokens: int = 128
    temperature: float = 0.7
    top_p: float = 0.9
    repetition_penalty: float = 1.1

    # When true the endpoint answers with a plain-text streaming response
    # instead of a JSON object.
    stream: bool = False
|
|
|
|
|
|
|
|
|
@app.post("/generate")
def generate(req: GenerateRequest):
    """Generate a continuation of ``req.prompt``.

    Returns ``{"response": <text>}``, or a ``text/plain`` streaming response
    when ``req.stream`` is true. In both modes the text starts with the prompt:
    the non-streaming path decodes the full output sequence, and the streaming
    path keeps the prompt (``skip_prompt=False``) so both modes agree.

    A ``temperature`` of 0 or below selects greedy decoding, because sampling
    requires a strictly positive temperature.
    """
    ensure_model_loaded()

    # The tokenizer returns CPU tensors; the model may have been placed on an
    # accelerator by device_map="auto", so move the inputs next to it.
    inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)

    do_sample = req.temperature > 0
    gen_kwargs = {
        "max_new_tokens": req.max_new_tokens,
        "repetition_penalty": req.repetition_penalty,
        "do_sample": do_sample,
    }
    if do_sample:
        gen_kwargs["temperature"] = req.temperature
        gen_kwargs["top_p"] = req.top_p

    if req.stream:

        def stream_tokens():
            """Yield decoded text pieces while generation runs on a worker thread."""
            streamer = TextIteratorStreamer(
                tokenizer, skip_prompt=False, skip_special_tokens=True
            )

            def run_generation():
                try:
                    model.generate(**inputs, **gen_kwargs, streamer=streamer)
                except Exception:
                    # Unblock the consumer loop below; without this a failed
                    # generation would leave the response hanging forever.
                    streamer.end()
                    raise

            worker = threading.Thread(target=run_generation)
            worker.start()
            for piece in streamer:
                yield piece
            worker.join()

        return StreamingResponse(stream_tokens(), media_type="text/plain")

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": full_text}
|
|
|
|
|
|
|
|
|
|
|
|
class ChatMessage(BaseModel):
    """One entry of a chat conversation."""

    # Speaker of the message; the prompt builder in this file distinguishes
    # "user" and "assistant" and treats every other value as a system message.
    role: str

    # The message text.
    content: str
|
|
|
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``."""

    # Model name requested by the client; optional.
    model: str | None = None

    # Conversation so far, oldest message first.
    messages: list[ChatMessage]

    # Sampling controls.
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256

    # When true the endpoint answers with a server-sent event stream.
    stream: bool = False

    # Stop sequence(s). A single string is accepted as well as a list, so
    # clients that send ``"stop": "\n"`` are not rejected during validation.
    stop: str | list[str] | None = None
|
|
|
|
|
|
|
|
|
class CompletionRequest(BaseModel):
    """Request body accepted by ``POST /v1/completions``."""

    # Model name requested by the client; optional.
    model: str | None = None

    # Text the model should continue.
    prompt: str

    # Sampling controls.
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256

    # When true the endpoint answers with a server-sent event stream.
    stream: bool = False

    # Stop sequence(s). A single string is accepted as well as a list, so
    # clients that send ``"stop": "\n"`` are not rejected during validation.
    stop: str | list[str] | None = None
|
|
|
|
|
|
|
|
|
def build_prompt_from_messages(messages: list[ChatMessage]) -> str:
    """Turn a chat history into a single prompt string.

    The tokenizer's chat template is tried first. If that fails for any
    reason, or yields an empty string, the messages are rendered one per line
    as "<label> <content>", followed by a trailing assistant label that cues
    the model to answer.

    Args:
        messages: conversation entries exposing ``role`` and ``content``.

    Returns:
        The prompt text to feed to the tokenizer.
    """
    try:
        # Prefer pydantic v2's model_dump(); fall back to the v1 spelling.
        payload = [
            m.model_dump() if hasattr(m, "model_dump") else m.dict()
            for m in messages
        ]
        formatted = tokenizer.apply_chat_template(
            payload,
            tokenize=False,
            add_generation_prompt=True,
        )
        if isinstance(formatted, str) and formatted.strip():
            return formatted
    except Exception as exc:
        # Deliberately broad: templating is best-effort and any failure should
        # fall through to the plain format below, but say why it happened.
        print(f"Chat template unavailable ({exc!r}); using plain-text prompt format.")

    # Fallback labels; any role other than user/assistant is shown as system.
    role_labels = {"user": "Корисник:", "assistant": "Асистент:"}
    lines = [f"{role_labels.get(m.role, 'Систем:')} {m.content}" for m in messages]
    lines.append("Асистент:")
    return "\n".join(lines)
|
|
|
|
|
|
|
|
|
def sse_pack(data: dict) -> str:
    """Serialise *data* as one server-sent-event ``data:`` frame.

    Non-ASCII characters are emitted as-is rather than as ``\\uXXXX`` escapes,
    and the frame ends with the blank line that terminates an event.
    """
    payload = json.dumps(data, ensure_ascii=False)
    return "data: " + payload + "\n\n"
|
|
|
|
|
|
|
|
|
@app.post("/v1/completions")
def completions(req: CompletionRequest):
    """OpenAI-style text completion endpoint.

    Non-streaming requests return a ``text_completion`` object whose ``text``
    holds only the newly generated tokens (the prompt is not echoed) and whose
    ``usage`` counts prompt and generated tokens separately. Streaming
    requests return server-sent events: an empty first chunk, one chunk per
    decoded piece, a final chunk with ``finish_reason="stop"``, then
    ``[DONE]``.

    A ``temperature`` of 0 or below selects greedy decoding, because sampling
    requires a strictly positive temperature. ``req.stop`` is accepted for API
    compatibility but is not applied by this handler.
    """
    ensure_model_loaded()

    # The tokenizer returns CPU tensors; the model may have been placed on an
    # accelerator by device_map="auto", so move the inputs next to it.
    inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
    prompt_tokens = int(inputs["input_ids"].shape[-1])

    do_sample = req.temperature > 0
    gen_kwargs = {"max_new_tokens": req.max_tokens, "do_sample": do_sample}
    if do_sample:
        gen_kwargs["temperature"] = req.temperature
        gen_kwargs["top_p"] = req.top_p

    request_id = f"cmpl-{uuid.uuid4().hex[:24]}"
    model_name = os.getenv("MODEL_ID", "mk-llm")
    created = int(time.time())

    if req.stream:

        def make_chunk(text, finish_reason=None):
            """Build one streaming chunk carrying *text*."""
            return {
                "id": request_id,
                "object": "text_completion.chunk",
                "created": created,
                "model": model_name,
                "choices": [{"text": text, "index": 0, "finish_reason": finish_reason}],
            }

        def event_stream():
            streamer = TextIteratorStreamer(
                tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            failed = threading.Event()

            def run_generation():
                try:
                    model.generate(**inputs, **gen_kwargs, streamer=streamer)
                except Exception:
                    # Unblock the consumer loop below; without this a failed
                    # generation would leave the response hanging forever.
                    failed.set()
                    streamer.end()
                    raise

            worker = threading.Thread(target=run_generation)
            worker.start()

            yield sse_pack(make_chunk(""))
            for token_text in streamer:
                yield sse_pack(make_chunk(token_text))
            worker.join()

            # Only claim a clean stop when generation actually completed.
            if not failed.is_set():
                yield sse_pack(make_chunk("", finish_reason="stop"))
            yield "data: [DONE]\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # The output sequence starts with the prompt tokens; keep only what the
    # model added so the prompt is neither echoed nor counted as completion.
    new_tokens = outputs[0][prompt_tokens:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    completion_tokens = int(new_tokens.shape[-1])

    return {
        "id": request_id,
        "object": "text_completion",
        "created": created,
        "model": model_name,
        "choices": [{"text": text, "index": 0, "finish_reason": "stop"}],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
|
|
|
|
|
|
|
|
|
@app.post("/v1/chat/completions")
def chat_completions(req: ChatCompletionRequest):
    """OpenAI-style chat completion endpoint.

    The conversation is flattened into a prompt with
    ``build_prompt_from_messages``. Non-streaming requests return a
    ``chat.completion`` object whose assistant message holds only the newly
    generated tokens (the prompt is not echoed) and whose ``usage`` counts
    prompt and generated tokens separately. Streaming requests return
    server-sent events: a role chunk, one content chunk per decoded piece, a
    final chunk with ``finish_reason="stop"``, then ``[DONE]``.

    A ``temperature`` of 0 or below selects greedy decoding, because sampling
    requires a strictly positive temperature. ``req.stop`` is accepted for API
    compatibility but is not applied by this handler.
    """
    ensure_model_loaded()

    prompt = build_prompt_from_messages(req.messages)
    # The tokenizer returns CPU tensors; the model may have been placed on an
    # accelerator by device_map="auto", so move the inputs next to it.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_tokens = int(inputs["input_ids"].shape[-1])

    do_sample = req.temperature > 0
    gen_kwargs = {"max_new_tokens": req.max_tokens, "do_sample": do_sample}
    if do_sample:
        gen_kwargs["temperature"] = req.temperature
        gen_kwargs["top_p"] = req.top_p

    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
    model_name = os.getenv("MODEL_ID", "mk-llm")
    created = int(time.time())

    if req.stream:

        def make_chunk(delta, finish_reason=None):
            """Build one streaming chunk carrying *delta*."""
            return {
                "id": request_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_name,
                "choices": [{"delta": delta, "index": 0, "finish_reason": finish_reason}],
            }

        def event_stream():
            streamer = TextIteratorStreamer(
                tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            failed = threading.Event()

            def run_generation():
                try:
                    model.generate(**inputs, **gen_kwargs, streamer=streamer)
                except Exception:
                    # Unblock the consumer loop below; without this a failed
                    # generation would leave the response hanging forever.
                    failed.set()
                    streamer.end()
                    raise

            worker = threading.Thread(target=run_generation)
            worker.start()

            # The first chunk announces the speaker; content follows.
            yield sse_pack(make_chunk({"role": "assistant"}))
            for token_text in streamer:
                yield sse_pack(make_chunk({"content": token_text}))
            worker.join()

            # Only claim a clean stop when generation actually completed.
            if not failed.is_set():
                yield sse_pack(make_chunk({}, finish_reason="stop"))
            yield "data: [DONE]\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # The output sequence starts with the prompt tokens; keep only what the
    # model added so the reply does not repeat the whole conversation.
    new_tokens = outputs[0][prompt_tokens:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    completion_tokens = int(new_tokens.shape[-1])

    return {
        "id": request_id,
        "object": "chat.completion",
        "created": created,
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
|
|
|
|
|
|
@app.get("/v1/models")
def list_models():
    """Return the OpenAI-style model list, containing the single served model.

    The id comes from the ``MODEL_ID`` environment variable (default
    ``"mk-llm"``), the same lookup the completion endpoints use for their
    ``model`` field, so clients see one consistent name.
    """
    model_entry = {
        "id": os.getenv("MODEL_ID", "mk-llm"),
        "object": "model",
        # Reported as the time of the request, not of model creation.
        "created": int(time.time()),
        "owned_by": "community",
    }
    return {"object": "list", "data": [model_entry]}
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported here, so uvicorn is only needed when this file is run as a
    # script rather than imported as a module.
    import uvicorn

    # Bind address and port come from the HOST / PORT settings above.
    server_options = {"host": HOST, "port": PORT}
    uvicorn.run(app, **server_options)
|
|
|
|