caarleexx commited on
Commit
0849add
Β·
verified Β·
1 Parent(s): c11c818

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +65 -107
app/main.py CHANGED
@@ -3,65 +3,61 @@ from typing import List, Optional
3
  from fastapi import FastAPI, HTTPException
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel, Field
6
- import ctranslate2
7
- from transformers import AutoTokenizer
8
 
9
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
10
  logger = logging.getLogger(__name__)
11
 
12
- # ─── Config via ENV ───────────────────────────────────────────────────────────
13
- CT2_MODEL_DIR = os.getenv("CT2_MODEL_DIR", "/app/ct2_model")
14
- TOKENIZER_DIR = os.getenv("TOKENIZER_DIR", "/app/tokenizer")
15
- CT2_MODEL_ID = os.getenv("CT2_MODEL_ID", "limcheekin/flan-t5-xxl-ct2")
16
- TOKENIZER_ID = os.getenv("TOKENIZER_ID", "google/flan-t5-xxl")
17
  MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "1024"))
18
- MAX_INPUT_LEN = int(os.getenv("MAX_INPUT_LEN", "512")) # 16384 se usar long-t5
19
- INTER_THREADS = int(os.getenv("INTER_THREADS", "2"))
20
- INTRA_THREADS = int(os.getenv("INTRA_THREADS", "2"))
21
 
22
- # ─── App ─────────────────────────────────────────────────────────────────────
23
- app = FastAPI(title="T2T OpenAI-Compatible API", version="2.0.0")
24
  app.add_middleware(
25
  CORSMiddleware,
26
- allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
27
  )
28
 
29
- _translator: ctranslate2.Translator = None
30
- _tokenizer: AutoTokenizer = None
31
 
32
- # ─── Startup ─────────────────────────────────────────────────────────────────
33
  @app.on_event("startup")
34
  def load_model():
35
- global _translator, _tokenizer
36
- logger.info(f"⏳ Carregando CT2 model: {CT2_MODEL_DIR}")
37
- _translator = ctranslate2.Translator(
38
- CT2_MODEL_DIR,
39
- device = "cpu",
40
- compute_type = "int8",
41
- inter_threads = INTER_THREADS,
42
- intra_threads = INTRA_THREADS,
43
  )
44
- logger.info(f"⏳ Carregando tokenizer: {TOKENIZER_DIR}")
45
- _tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
46
- logger.info("βœ… Model pronto!")
 
 
 
 
 
47
 
48
- # ─── Schemas (espelho exato do OpenAI) ───────────────────────────────────────
49
  class Message(BaseModel):
50
- role: str # "system" | "user" | "assistant"
51
  content: str
52
 
53
  class ResponseFormat(BaseModel):
54
  type: str = "text"
55
 
56
  class ChatCompletionRequest(BaseModel):
57
- model: str = Field(default=CT2_MODEL_ID)
58
  messages: List[Message]
59
- temperature: float = 0.7
60
- top_p: float = 0.9
61
- max_completion_tokens: Optional[int] = None
62
  max_tokens: Optional[int] = None
63
  response_format: Optional[ResponseFormat] = None
64
- stream: bool = False
65
 
66
  class Config:
67
  populate_by_name = True
@@ -71,29 +67,25 @@ class ChoiceMessage(BaseModel):
71
  content: str
72
 
73
  class Choice(BaseModel):
74
- index: int
75
- message: ChoiceMessage
76
  finish_reason: str = "stop"
77
 
78
  class Usage(BaseModel):
79
- prompt_tokens: int
80
  completion_tokens: int
81
- total_tokens: int
82
 
83
  class ChatCompletionResponse(BaseModel):
84
- id: str
85
- object: str = "chat.completion"
86
  created: int
87
- model: str
88
  choices: List[Choice]
89
- usage: Usage
90
 
91
- # ─── Helpers ─────────────────────────────────────────────────────────────────
92
  def messages_to_prompt(messages: List[Message]) -> str:
93
- """
94
- Converte lista de mensagens em prompt ΓΊnico para modelos seq2seq.
95
- Preserva contexto system + histΓ³rico de conversa.
96
- """
97
  parts = []
98
  for m in messages:
99
  if m.role == "system":
@@ -102,95 +94,61 @@ def messages_to_prompt(messages: List[Message]) -> str:
102
  parts.append(f"User: {m.content}")
103
  elif m.role == "assistant":
104
  parts.append(f"Assistant: {m.content}")
105
- return "\n".join(parts)
106
 
107
- def count_tokens(text: str) -> int:
108
  return len(_tokenizer(text, add_special_tokens=False)["input_ids"])
109
 
110
- # ─── Endpoints ───────────────────────────────────────────────────────────────
111
  @app.get("/")
112
  def root():
113
- return {
114
- "status": "ok",
115
- "model": CT2_MODEL_ID,
116
- "max_input_tokens": MAX_INPUT_LEN,
117
- "max_output_tokens": MAX_NEW_TOKENS,
118
- }
119
 
120
  @app.get("/health")
121
  def health():
122
- return {
123
- "status": "healthy",
124
- "model": CT2_MODEL_ID,
125
- "model_loaded": _translator is not None,
126
- }
127
 
128
  @app.get("/v1/models")
129
  def list_models():
130
- return {
131
- "object": "list",
132
- "data": [{
133
- "id": CT2_MODEL_ID,
134
- "object": "model",
135
- "owned_by": "huggingface",
136
- }],
137
- }
138
 
139
  @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
140
  def chat_completions(req: ChatCompletionRequest):
141
  if req.stream:
142
- raise HTTPException(501, "Streaming nΓ£o suportado ainda.")
143
-
144
- if _translator is None or _tokenizer is None:
145
- raise HTTPException(503, "Modelo ainda nΓ£o carregado.")
146
 
147
  max_tokens = req.max_completion_tokens or req.max_tokens or MAX_NEW_TOKENS
148
  prompt = messages_to_prompt(req.messages)
149
-
150
- # Tokeniza com truncation para respeitar janela do modelo
151
- encoded = _tokenizer(
152
- prompt,
153
- return_tensors = None,
154
- truncation = True,
155
- max_length = MAX_INPUT_LEN,
156
- add_special_tokens = True,
157
- )
158
- input_tokens = [_tokenizer.convert_ids_to_tokens(encoded["input_ids"])]
159
-
160
- do_sample = req.temperature > 0.05
161
 
162
  try:
163
- results = _translator.translate_batch(
164
- input_tokens,
165
- max_decoding_length = max_tokens,
166
- min_decoding_length = 1,
167
- beam_size = 1 if do_sample else 4,
168
- sampling_temperature = float(req.temperature) if do_sample else 1.0,
169
- sampling_topk = 50 if do_sample else 1,
170
- sampling_topp = float(req.top_p) if do_sample else 1.0,
171
- repetition_penalty = 1.2, # evita repetiΓ§Γ΅es em textos longos
172
  )
173
  except Exception as e:
174
  logger.error(f"Inference error: {e}")
175
- raise HTTPException(500, f"Erro na inferΓͺncia: {e}")
176
-
177
- output_tokens = results[0].hypotheses[0]
178
- generated = _tokenizer.convert_tokens_to_string(output_tokens).strip()
179
 
180
- p_tok = count_tokens(prompt)
181
- c_tok = count_tokens(generated)
 
182
 
183
  return ChatCompletionResponse(
184
  id = f"chatcmpl-{uuid.uuid4().hex[:12]}",
185
  created = int(time.time()),
186
  model = req.model,
187
- choices = [
188
- Choice(
189
- index = 0,
190
- message = ChoiceMessage(content=generated),
191
- )
192
- ],
193
- usage = Usage(
194
  prompt_tokens = p_tok,
195
  completion_tokens = c_tok,
196
  total_tokens = p_tok + c_tok,
 
3
  from fastapi import FastAPI, HTTPException
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel, Field
6
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
7
+ import torch
8
 
9
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
10
  logger = logging.getLogger(__name__)
11
 
12
+ MODEL_ID = os.getenv("MODEL_ID", "google/flan-t5-large")
 
 
 
 
13
  MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "1024"))
14
+ MAX_INPUT_LEN = int(os.getenv("MAX_INPUT_LEN", "512"))
 
 
15
 
16
+ app = FastAPI(title="T2T OpenAI-Compatible API", version="3.0.0")
 
17
  app.add_middleware(
18
  CORSMiddleware,
19
+ allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
20
  )
21
 
22
+ _pipe = None
23
+ _tokenizer = None
24
 
 
25
  @app.on_event("startup")
26
  def load_model():
27
+ global _pipe, _tokenizer
28
+ logger.info(f"⏳ Carregando {MODEL_ID} …")
29
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
30
+ model = AutoModelForSeq2SeqLM.from_pretrained(
31
+ MODEL_ID,
32
+ torch_dtype = torch.float32,
33
+ low_cpu_mem_usage = True,
 
34
  )
35
+ model.eval()
36
+ _pipe = pipeline(
37
+ "text2text-generation",
38
+ model = model,
39
+ tokenizer = _tokenizer,
40
+ device = -1, # forΓ§a CPU
41
+ )
42
+ logger.info(f"βœ… {MODEL_ID} pronto!")
43
 
44
+ # ── Schemas (OpenAI-compatible) ───────────────────────────────────────────
45
  class Message(BaseModel):
46
+ role: str
47
  content: str
48
 
49
  class ResponseFormat(BaseModel):
50
  type: str = "text"
51
 
52
  class ChatCompletionRequest(BaseModel):
53
+ model: str = Field(default=MODEL_ID)
54
  messages: List[Message]
55
+ temperature: float = 0.7
56
+ top_p: float = 0.9
57
+ max_completion_tokens: Optional[int] = None
58
  max_tokens: Optional[int] = None
59
  response_format: Optional[ResponseFormat] = None
60
+ stream: bool = False
61
 
62
  class Config:
63
  populate_by_name = True
 
67
  content: str
68
 
69
  class Choice(BaseModel):
70
+ index: int
71
+ message: ChoiceMessage
72
  finish_reason: str = "stop"
73
 
74
  class Usage(BaseModel):
75
+ prompt_tokens: int
76
  completion_tokens: int
77
+ total_tokens: int
78
 
79
  class ChatCompletionResponse(BaseModel):
80
+ id: str
81
+ object: str = "chat.completion"
82
  created: int
83
+ model: str
84
  choices: List[Choice]
85
+ usage: Usage
86
 
87
+ # ── Helpers ───────────────────────────────────────────────────────────────
88
  def messages_to_prompt(messages: List[Message]) -> str:
 
 
 
 
89
  parts = []
90
  for m in messages:
91
  if m.role == "system":
 
94
  parts.append(f"User: {m.content}")
95
  elif m.role == "assistant":
96
  parts.append(f"Assistant: {m.content}")
97
+ return " ".join(parts)
98
 
99
+ def token_count(text: str) -> int:
100
  return len(_tokenizer(text, add_special_tokens=False)["input_ids"])
101
 
102
+ # ── Endpoints ─────────────────────────────────────────────────────────────
103
  @app.get("/")
104
  def root():
105
+ return {"status": "ok", "model": MODEL_ID}
 
 
 
 
 
106
 
107
  @app.get("/health")
108
  def health():
109
+ return {"status": "healthy", "model": MODEL_ID, "ready": _pipe is not None}
 
 
 
 
110
 
111
  @app.get("/v1/models")
112
  def list_models():
113
+ return {"object": "list", "data": [
114
+ {"id": MODEL_ID, "object": "model", "owned_by": "huggingface"}
115
+ ]}
 
 
 
 
 
116
 
117
  @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
118
  def chat_completions(req: ChatCompletionRequest):
119
  if req.stream:
120
+ raise HTTPException(501, "Streaming nΓ£o suportado.")
121
+ if _pipe is None:
122
+ raise HTTPException(503, "Modelo nΓ£o carregado.")
 
123
 
124
  max_tokens = req.max_completion_tokens or req.max_tokens or MAX_NEW_TOKENS
125
  prompt = messages_to_prompt(req.messages)
126
+ do_sample = req.temperature > 0.05
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  try:
129
+ output = _pipe(
130
+ prompt,
131
+ max_new_tokens = max_tokens,
132
+ truncation = True,
133
+ temperature = float(req.temperature) if do_sample else 1.0,
134
+ top_p = float(req.top_p) if do_sample else 1.0,
135
+ do_sample = do_sample,
136
+ repetition_penalty = 1.2,
 
137
  )
138
  except Exception as e:
139
  logger.error(f"Inference error: {e}")
140
+ raise HTTPException(500, str(e))
 
 
 
141
 
142
+ text = output[0]["generated_text"].strip()
143
+ p_tok = token_count(prompt)
144
+ c_tok = token_count(text)
145
 
146
  return ChatCompletionResponse(
147
  id = f"chatcmpl-{uuid.uuid4().hex[:12]}",
148
  created = int(time.time()),
149
  model = req.model,
150
+ choices = [Choice(index=0, message=ChoiceMessage(content=text))],
151
+ usage = Usage(
 
 
 
 
 
152
  prompt_tokens = p_tok,
153
  completion_tokens = c_tok,
154
  total_tokens = p_tok + c_tok,