import os
from typing import List, Literal

import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

# ----------------------------
# Model config (matches demo)
# ----------------------------
MODEL_NAME = os.getenv("MODEL_NAME", "MBZUAI-Paris/Nile-Chat-12B")
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2024"))

app = FastAPI(title="Nile-Chat-12B FastAPI")

tokenizer = None
model = None

# ----------------------------
# Request schemas
# ----------------------------
Role = Literal["system", "user", "assistant"]


class ChatMessage(BaseModel):
    role: Role
    content: str


class GenerateRequest(BaseModel):
    # Same idea as the Gradio demo (history + message), but unified here:
    # the client sends the full conversation as `messages`, and the last
    # user message is the current request.
    messages: List[ChatMessage] = Field(..., description="Conversation messages in OpenAI-like format")
    max_new_tokens: int = Field(DEFAULT_MAX_NEW_TOKENS, ge=1, le=MAX_MAX_NEW_TOKENS)
    do_sample: bool = True
    temperature: float = Field(0.6, ge=0.0, le=4.0)
    top_p: float = Field(0.9, ge=0.05, le=1.0)
    top_k: int = Field(50, ge=1, le=1000)
    repetition_penalty: float = Field(1.1, ge=1.0, le=2.0)


class GenerateResponse(BaseModel):
    response: str
    trimmed: bool = False
    model: str = MODEL_NAME


# ----------------------------
# Startup
# ----------------------------
@app.on_event("startup")
def startup_event():
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Same logic as the demo: bfloat16 + device_map="auto" on GPU,
    # float32 fallback on CPU.
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=dtype,
    )
    model.eval()
    print("Model ready")


# Health-check endpoint (as requested)
@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_NAME}


# ----------------------------
# Core generation
# ----------------------------
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    global tokenizer, model

    if not req.messages:
        return GenerateResponse(response="Error: messages is empty", trimmed=False)

    # The Nile-Chat demo applies apply_chat_template to the whole conversation.
    conversation = [m.model_dump() for m in req.messages]

    # Build input_ids exactly like the Gradio demo
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )

    # Trim from the left if the prompt exceeds the input-token budget.
    trimmed = False
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        trimmed = True

    input_ids = input_ids.to(model.device)

    # Logging
    last_user = next((m.content for m in reversed(req.messages) if m.role == "user"), "")
    print("\n=== Incoming Request ===")
    print("MODEL:", MODEL_NAME)
    print("LAST USER:", last_user)
    print("trimmed_input:", trimmed)
    print("input_tokens:", int(input_ids.shape[1]))

    # Generate (non-streaming API response)
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            max_new_tokens=req.max_new_tokens,
            do_sample=req.do_sample,
            top_p=req.top_p,
            top_k=req.top_k,
            temperature=req.temperature,
            num_beams=1,
            repetition_penalty=req.repetition_penalty,
        )

    # Decode only the newly generated tokens (same idea as the Qwen API)
    new_tokens = out[0, input_ids.shape[-1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    print("\n=== Model Response ===")
    print(response_text)
    print("======================\n")

    return GenerateResponse(response=response_text, trimmed=trimmed, model=MODEL_NAME)
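

# ----------------------------
# Local run (sketch)
# ----------------------------
# A minimal sketch for running this API locally, assuming uvicorn is installed;
# the host/port values below are illustrative defaults, not part of the original demo.
#
# Example request once the server is up (payload fields mirror GenerateRequest):
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "max_new_tokens": 128}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)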