File size: 4,038 Bytes
6867f65 de5fced 6cef2a4 6867f65 de5fced 6cef2a4 21725a9 de5fced 16cb566 de5fced 6867f65 de5fced 6867f65 de5fced 6867f65 de5fced 6867f65 de5fced 6867f65 de5fced 6867f65 6b6dd1e 6cef2a4 b6f476e de5fced b6f476e 6cef2a4 de5fced 6cef2a4 de5fced 6cef2a4 de5fced 6867f65 de5fced 6867f65 de5fced 6cef2a4 de5fced 6cef2a4 de5fced 6cef2a4 de5fced 6867f65 de5fced 6867f65 de5fced 6cef2a4 de5fced 6867f65 de5fced 6cef2a4 de5fced 6cef2a4 6867f65 de5fced 6867f65 6cef2a4 de5fced |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import os
from typing import List, Literal, Optional
import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
# ----------------------------
# Model config (matches demo)
# ----------------------------
# Hub id of the model to serve; overridable through the MODEL_NAME env var.
MODEL_NAME: str = os.environ.get("MODEL_NAME", "MBZUAI-Paris/Nile-Chat-12B")

# Upper bound and default for the per-request ``max_new_tokens`` value.
MAX_MAX_NEW_TOKENS: int = 2048
DEFAULT_MAX_NEW_TOKENS: int = 1024

# Longest prompt (in tokens) kept before generation; overridable through the
# MAX_INPUT_TOKEN_LENGTH env var, which must parse as an integer.
MAX_INPUT_TOKEN_LENGTH: int = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "2024"))
# FastAPI application object; the startup hook and the routes in this file
# are registered on it through its decorators.
app = FastAPI(title="Nile-Chat-12B FastAPI")
# Module-level handles for the tokenizer and model. They stay None until
# startup_event() assigns the loaded objects.
tokenizer = None
model = None
# ----------------------------
# Request schemas
# ----------------------------
# Closed set of values accepted for a chat message's ``role`` field.
Role = Literal["system", "user", "assistant"]
class ChatMessage(BaseModel):
    """One turn of a conversation: who is speaking and what was said."""

    # One of "system", "user" or "assistant" (see ``Role``).
    role: Role
    # Raw text of the turn.
    content: str
class GenerateRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint.

    Carries the conversation plus the decoding parameters; every numeric
    field is range-checked by its ``Field`` constraints.
    """

    # NOTE(review): the original comment here was mis-encoded. Its readable
    # parts mention Gradio's "history + message" inputs being unified into a
    # single ``messages`` list -- presumably with the last user message being
    # the current request. Confirm against the demo.
    messages: List[ChatMessage] = Field(
        ...,
        description="Conversation messages in OpenAI-like format",
    )

    # Number of tokens to generate, between 1 and MAX_MAX_NEW_TOKENS.
    max_new_tokens: int = Field(
        default=DEFAULT_MAX_NEW_TOKENS,
        ge=1,
        le=MAX_MAX_NEW_TOKENS,
    )

    # Sampling switch and sampling parameters.
    do_sample: bool = True
    temperature: float = Field(default=0.6, ge=0.0, le=4.0)
    top_p: float = Field(default=0.9, ge=0.05, le=1.0)
    top_k: int = Field(default=50, ge=1, le=1000)
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0)
class GenerateResponse(BaseModel):
    """Response body returned by the ``/generate`` endpoint."""

    # Generated text (or an error message when the request had no messages).
    response: str
    # True when the prompt was cut down to MAX_INPUT_TOKEN_LENGTH tokens.
    trimmed: bool = Field(default=False)
    # Name of the model that produced the response.
    model: str = Field(default=MODEL_NAME)
# ----------------------------
# Startup
# ----------------------------
@app.on_event("startup")
def startup_event():
    """Load the tokenizer and the model into the module-level globals.

    Registered as a FastAPI startup hook. Prints "Model ready" once both
    objects have been loaded.
    """
    global tokenizer, model

    # bfloat16 + device_map auto, per the original (mis-encoded) comment,
    # which appears to say this mirrors the demo; float32 is used when CUDA
    # is not available.
    if torch.cuda.is_available():
        weights_dtype = torch.bfloat16
    else:
        weights_dtype = torch.float32

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    loaded_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=weights_dtype,
    )
    loaded_model.eval()
    model = loaded_model

    # Readiness marker on stdout.
    print("Model ready")
@app.get("/health")
def health():
    """Return a fixed "ok" status together with the configured model name.

    The handler does not inspect the tokenizer or model objects.
    """
    return dict(status="ok", model=MODEL_NAME)
# ----------------------------
# Core generation
# ----------------------------
def _decoding_kwargs(req):
    """Translate a GenerateRequest into keyword arguments for ``model.generate``.

    Sampling parameters (``top_p``, ``top_k``, ``temperature``) are included
    only when sampling is actually performed.
    """
    # The request schema allows temperature == 0.0, but transformers rejects a
    # non-positive temperature when sampling. Treat it as a request for
    # deterministic (greedy) decoding instead of failing the request.
    sampling = bool(req.do_sample) and req.temperature > 0.0

    kwargs = {
        "max_new_tokens": req.max_new_tokens,
        "do_sample": sampling,
        "num_beams": 1,
        "repetition_penalty": req.repetition_penalty,
    }
    if sampling:
        kwargs["top_p"] = req.top_p
        kwargs["top_k"] = req.top_k
        kwargs["temperature"] = req.temperature
    return kwargs


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate one assistant reply for the given conversation.

    Args:
        req: Conversation messages plus decoding parameters.

    Returns:
        A GenerateResponse holding the decoded new tokens, a flag telling
        whether the prompt was trimmed, and the model name. When
        ``req.messages`` is empty, the response text is an error message.
    """
    if not req.messages:
        return GenerateResponse(response="Error: messages is empty", trimmed=False)

    # The chat template is applied to the whole conversation (the original
    # comment says the Nile-Chat demo does the same).
    conversation = [m.model_dump() for m in req.messages]

    # Ask for a bare tensor explicitly: the code below relies on ``.shape``
    # and slicing. NOTE(review): some transformers versions may default to a
    # dict-like return here -- confirm against the installed version.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=False,
    )

    # Keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens of the prompt.
    trimmed = input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH
    if trimmed:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]

    input_ids = input_ids.to(model.device)

    # Logging
    last_user = next(
        (m.content for m in reversed(req.messages) if m.role == "user"),
        "",
    )
    print("\n=== Incoming Request ===")
    print("MODEL:", MODEL_NAME)
    print("LAST USER:", last_user)
    print("trimmed_input:", trimmed)
    print("input_tokens:", int(input_ids.shape[1]))

    # Generate (non-streaming API response)
    with torch.no_grad():
        out = model.generate(input_ids=input_ids, **_decoding_kwargs(req))

    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = out[0, input_ids.shape[-1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    print("\n=== Model Response ===")
    print(response_text)
    print("======================\n")

    return GenerateResponse(response=response_text, trimmed=trimmed, model=MODEL_NAME)
|