Mr-Help commited on
Commit
de5fced
ยท
verified ยท
1 Parent(s): aa32412

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +82 -34
main.py CHANGED
@@ -1,45 +1,74 @@
1
  import os
 
 
2
  import torch
3
  from fastapi import FastAPI
4
- from pydantic import BaseModel
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-1.5B-Instruct")
 
 
 
 
 
 
 
8
 
9
- app = FastAPI(title="Qwen FastAPI")
10
 
11
  tokenizer = None
12
  model = None
13
 
14
 
 
 
 
 
 
 
 
 
 
15
  class GenerateRequest(BaseModel):
16
- system_prompt: str
17
- user_prompt: str
18
- max_new_tokens: int = 400
19
- temperature: float = 0.7
20
- top_p: float = 0.9
21
  do_sample: bool = True
 
 
 
 
 
22
 
 
 
 
 
23
 
 
 
 
 
24
  @app.on_event("startup")
25
  def startup_event():
26
  global tokenizer, model
27
 
28
- # Load tokenizer
29
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
30
 
31
- # dtype: bfloat16 on CUDA, float32 on CPU
32
- has_cuda = torch.cuda.is_available()
33
- dtype = torch.bfloat16 if has_cuda else torch.float32
34
 
35
- # Load model (auto device placement)
36
  model = AutoModelForCausalLM.from_pretrained(
37
  MODEL_NAME,
 
38
  torch_dtype=dtype,
39
- device_map="auto"
40
  )
 
41
 
42
- print("Model ready") # โœ… ู…ุทู„ูˆุจ ู…ู†ูƒ
43
 
44
 
45
  @app.get("/health")
@@ -47,41 +76,60 @@ def health():
47
  return {"status": "ok", "model": MODEL_NAME}
48
 
49
 
50
- @app.post("/generate")
 
 
 
51
  def generate(req: GenerateRequest):
52
  global tokenizer, model
53
 
54
- messages = [
55
- {"role": "system", "content": req.system_prompt},
56
- {"role": "user", "content": req.user_prompt}
57
- ]
 
58
 
59
- text = tokenizer.apply_chat_template(
60
- messages,
61
- tokenize=False,
62
- add_generation_prompt=True
 
63
  )
64
 
65
- model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
 
 
 
 
66
 
 
 
67
  print("\n=== Incoming Request ===")
68
- print("SYSTEM:", req.system_prompt)
69
- print("USER:", req.user_prompt)
 
 
70
 
 
71
  with torch.no_grad():
72
- generated_ids = model.generate(
73
- **model_inputs,
74
  max_new_tokens=req.max_new_tokens,
75
  do_sample=req.do_sample,
76
- temperature=req.temperature,
77
  top_p=req.top_p,
 
 
 
 
78
  )
79
 
80
- new_tokens = generated_ids[0, model_inputs["input_ids"].shape[-1]:]
81
- response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
 
82
 
83
  print("\n=== Model Response ===")
84
- print(response)
85
  print("======================\n")
86
 
87
- return {"response": response}
 
1
  import os
2
+ from typing import List, Literal, Optional
3
+
4
  import torch
5
  from fastapi import FastAPI
6
+ from pydantic import BaseModel, Field
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
+ # ----------------------------
10
+ # Model config (matches demo)
11
+ # ----------------------------
12
+ MODEL_NAME = os.getenv("MODEL_NAME", "MBZUAI-Paris/Nile-Chat-12B")
13
+
14
+ MAX_MAX_NEW_TOKENS = 2048
15
+ DEFAULT_MAX_NEW_TOKENS = 1024
16
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2024"))
17
 
18
+ app = FastAPI(title="Nile-Chat-12B FastAPI")
19
 
20
  tokenizer = None
21
  model = None
22
 
23
 
24
+ # ----------------------------
25
+ # Request schemas
26
+ # ----------------------------
27
+ Role = Literal["system", "user", "assistant"]
28
+
29
+ class ChatMessage(BaseModel):
30
+ role: Role
31
+ content: str
32
+
33
  class GenerateRequest(BaseModel):
34
+ # ู†ูุณ ู…ูู‡ูˆู… Gradio: history + message
35
+ # ู„ูƒู† ู‡ู†ุง ู‡ู†ูˆุญู‘ุฏู‡ุง: messages ูƒุงู…ู„ุฉุŒ ูˆุขุฎุฑ user message ู‡ูŠ ุงู„ุทู„ุจ ุงู„ุญุงู„ูŠ
36
+ messages: List[ChatMessage] = Field(..., description="Conversation messages in OpenAI-like format")
37
+
38
+ max_new_tokens: int = Field(DEFAULT_MAX_NEW_TOKENS, ge=1, le=MAX_MAX_NEW_TOKENS)
39
  do_sample: bool = True
40
+ temperature: float = Field(0.6, ge=0.0, le=4.0)
41
+ top_p: float = Field(0.9, ge=0.05, le=1.0)
42
+ top_k: int = Field(50, ge=1, le=1000)
43
+ repetition_penalty: float = Field(1.1, ge=1.0, le=2.0)
44
+
45
 
46
+ class GenerateResponse(BaseModel):
47
+ response: str
48
+ trimmed: bool = False
49
+ model: str = MODEL_NAME
50
 
51
+
52
+ # ----------------------------
53
+ # Startup
54
+ # ----------------------------
55
  @app.on_event("startup")
56
  def startup_event():
57
  global tokenizer, model
58
 
 
59
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
60
 
61
+ # ู†ูุณ ู…ู†ุทู‚ ุงู„ุฏูŠู…ูˆ: bfloat16 + device_map auto
62
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
63
 
 
64
  model = AutoModelForCausalLM.from_pretrained(
65
  MODEL_NAME,
66
+ device_map="auto",
67
  torch_dtype=dtype,
 
68
  )
69
+ model.eval()
70
 
71
+ print("Model ready") # โœ… ุฒูŠ ู…ุง ุทู„ุจุช
72
 
73
 
74
  @app.get("/health")
 
76
  return {"status": "ok", "model": MODEL_NAME}
77
 
78
 
79
+ # ----------------------------
80
+ # Core generation
81
+ # ----------------------------
82
+ @app.post("/generate", response_model=GenerateResponse)
83
  def generate(req: GenerateRequest):
84
  global tokenizer, model
85
 
86
+ if not req.messages:
87
+ return GenerateResponse(response="Error: messages is empty", trimmed=False)
88
+
89
+ # Nile-Chat demo ุจูŠุณุชุฎุฏู… apply_chat_template ุนู„ู‰ conversation ูƒู„ู‡ุง
90
+ conversation = [m.model_dump() for m in req.messages]
91
 
92
+ # Build input_ids exactly like the Gradio demo
93
+ input_ids = tokenizer.apply_chat_template(
94
+ conversation,
95
+ add_generation_prompt=True,
96
+ return_tensors="pt"
97
  )
98
 
99
+ trimmed = False
100
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
101
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
102
+ trimmed = True
103
+
104
+ input_ids = input_ids.to(model.device)
105
 
106
+ # Logging
107
+ last_user = next((m.content for m in reversed(req.messages) if m.role == "user"), "")
108
  print("\n=== Incoming Request ===")
109
+ print("MODEL:", MODEL_NAME)
110
+ print("LAST USER:", last_user)
111
+ print("trimmed_input:", trimmed)
112
+ print("input_tokens:", int(input_ids.shape[1]))
113
 
114
+ # Generate (non-streaming API response)
115
  with torch.no_grad():
116
+ out = model.generate(
117
+ input_ids=input_ids,
118
  max_new_tokens=req.max_new_tokens,
119
  do_sample=req.do_sample,
 
120
  top_p=req.top_p,
121
+ top_k=req.top_k,
122
+ temperature=req.temperature,
123
+ num_beams=1,
124
+ repetition_penalty=req.repetition_penalty,
125
  )
126
 
127
+ # Decode only new tokens (same idea as your Qwen API)
128
+ new_tokens = out[0, input_ids.shape[-1]:]
129
+ response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
130
 
131
  print("\n=== Model Response ===")
132
+ print(response_text)
133
  print("======================\n")
134
 
135
+ return GenerateResponse(response=response_text, trimmed=trimmed, model=MODEL_NAME)