jnjj committed on
Commit 499dbc9 · verified · 1 Parent(s): 77efeb9

Create app.py

Files changed (1)
  1. app.py +1491 -0
app.py ADDED
@@ -0,0 +1,1491 @@
1
+ import os
2
+ import gc
3
+ import json
4
+ import random
5
+ import torch
6
+ import asyncio
7
+ import logging
8
+ import time
9
+ from typing import List, Dict, Any, Optional, Union, AsyncGenerator, Tuple
10
+ from fastapi import FastAPI, HTTPException, Query, Request, Depends, status
11
+ from fastapi.responses import StreamingResponse, PlainTextResponse, HTMLResponse, JSONResponse
12
+ from fastapi.security import APIKeyHeader
13
+ from pydantic import BaseModel, Field, ValidationError, validator
14
+ from transformers import (
15
+ AutoConfig, AutoModelForCausalLM, AutoTokenizer,
16
+ GenerationConfig, LogitsProcessorList,
17
+ MinLengthLogitsProcessor, MaxLengthCriteria,
18
+ StoppingCriteriaList, StoppingCriteria
19
+ )
20
+ import uvicorn
21
+ from concurrent.futures import ThreadPoolExecutor
22
+ import math
23
+ import torch.nn.functional as F
24
+ import copy
25
+
26
+ app = FastAPI(title="Chatbot Profesional API", version="1.0.0")
27
+
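+ # Stopping criterion that ends generation once any configured stop sequence appears at the end of the decoded output.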
28
+ class StopSequenceCriteria(StoppingCriteria):
29
+ def __init__(self, stop_sequences: List[str], tokenizer: AutoTokenizer):
30
+ self.tokenizer = tokenizer
31
+ self.stop_sequences_text = []
32
+ self.stop_sequence_ids = []
33
+ for seq in stop_sequences:
34
+ if seq:
35
+ encoded_ids = tokenizer.encode(seq, add_special_tokens=False)
36
+ decoded_text = tokenizer.decode(encoded_ids, skip_special_tokens=True)
37
+ if decoded_text:
38
+ self.stop_sequences_text.append(decoded_text)
39
+ self.stop_sequence_ids.append(encoded_ids)
40
+
41
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
42
+ if not self.stop_sequence_ids:
43
+ return False
44
+
45
+ input_ids_list = input_ids[0].tolist()
46
+
47
+ for stop_seq_ids in self.stop_sequence_ids:
48
+ stop_len = len(stop_seq_ids)
49
+ if len(input_ids_list) >= stop_len:
50
+ if input_ids_list[-stop_len:] == stop_seq_ids:
51
+ return True
52
+
53
+ check_tail_len = 50
54
+ if self.stop_sequence_ids:
55
+ max_stop_seq_token_len = max((len(seq) for seq in self.stop_sequence_ids), default=0)
56
+ check_tail_len = max(check_tail_len, max_stop_seq_token_len + 10)
57
+
58
+ tail_ids = input_ids_list[-min(check_tail_len, len(input_ids_list)):]
59
+ tail_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
60
+
61
+ for stop_seq_text in self.stop_sequences_text:
62
+ if stop_seq_text and stop_seq_text in tail_text:
63
+ return True
64
+
65
+ return False
66
+
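+ # Silence uvicorn, FastAPI and transformers loggers; the service is intended to run without console noise.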
67
+ logging.getLogger("uvicorn").handlers.clear()
68
+ logging.getLogger("uvicorn.error").handlers.clear()
69
+ logging.getLogger("uvicorn.access").handlers.clear()
70
+ logging.getLogger("uvicorn").propagate = False
71
+ logging.getLogger("uvicorn.error").propagate = False
72
+ logging.getLogger("uvicorn.access").propagate = False
73
+ logging.getLogger("uvicorn").setLevel(logging.CRITICAL)
74
+ logging.getLogger("uvicorn.error").setLevel(logging.CRITICAL)
75
+ logging.getLogger("uvicorn.access").setLevel(logging.CRITICAL)
76
+ logging.getLogger("fastapi").setLevel(logging.CRITICAL)
77
+ logging.getLogger("transformers").setLevel(logging.CRITICAL)
78
+ logging.getLogger().handlers.clear()
79
+ logging.getLogger().addHandler(logging.NullHandler())
80
+
81
+ DEFAULT_MODEL_NAME = "hghghgkskdmskdms/xddd"
82
+ MODEL_NAME = os.environ.get("MODEL_NAME", DEFAULT_MODEL_NAME)
83
+ SYSTEM_PROMPT = os.environ.get("SYSTEM_PROMPT", "Eres un asistente profesional y servicial.")
84
+
85
+ try:
86
+ MAX_CONTEXT_TOKENS = int(os.environ.get("MAX_CONTEXT_TOKENS", 1024))
87
+ if MAX_CONTEXT_TOKENS <= 0:
88
+ raise ValueError("MAX_CONTEXT_TOKENS must be positive.")
89
+ except (ValueError, TypeError) as e:
90
+ logging.error(f"Invalid MAX_CONTEXT_TOKENS environment variable: {os.environ.get('MAX_CONTEXT_TOKENS')}. Using default 1024. Error: {e}")
91
+ MAX_CONTEXT_TOKENS = 1024
92
+
93
+ try:
94
+ MAX_GENERATION_TOKENS = int(os.environ.get("MAX_GENERATION_TOKENS", 512))
95
+ if MAX_GENERATION_TOKENS <= 0:
96
+ raise ValueError("MAX_GENERATION_TOKENS must be positive.")
97
+ except (ValueError, TypeError) as e:
98
+ logging.error(f"Invalid MAX_GENERATION_TOKENS environment variable: {os.environ.get('MAX_GENERATION_TOKENS')}. Using default 512. Error: {e}")
99
+ MAX_GENERATION_TOKENS = 512
100
+
101
+ try:
102
+ MAX_CONCURRENT_GENERATIONS = int(os.environ.get("MAX_CONCURRENT_GENERATIONS", 4))
103
+ if MAX_CONCURRENT_GENERATIONS <= 0:
104
+ raise ValueError("MAX_CONCURRENT_GENERATIONS must be positive.")
105
+ except (ValueError, TypeError) as e:
106
+ logging.error(f"Invalid MAX_CONCURRENT_GENERATIONS environment variable: {os.environ.get('MAX_CONCURRENT_GENERATIONS')}. Using default 4. Error: {e}")
107
+ MAX_CONCURRENT_GENERATIONS = 4
108
+
109
+ TRUST_REMOTE_CODE_ENV = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true"
110
+ TRUST_REMOTE_CODE = TRUST_REMOTE_CODE_ENV or (MODEL_NAME == DEFAULT_MODEL_NAME)
111
+ ENABLE_FLASH_ATTENTION_2 = os.environ.get("ENABLE_FLASH_ATTENTION_2", "false").lower() == "true"
112
+ TORCH_DTYPE_STR = os.environ.get("TORCH_DTYPE", "float32")
113
+ TORCH_DTYPE = getattr(torch, TORCH_DTYPE_STR.lower(), torch.float32)
114
+ if TORCH_DTYPE != torch.float32:
115
+ logging.warning(f"Requested dtype {TORCH_DTYPE_STR} might not be fully performant on CPU. Using float32.")
116
+ TORCH_DTYPE = torch.float32
117
+
118
+ API_KEY = os.environ.get("API_KEY")
119
+
120
+ global_model = None
121
+ global_tokenizer = None
122
+ global_tokens: Dict[str, Optional[int]] = {}
123
+ executor = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_GENERATIONS)
124
+ generation_semaphore = asyncio.Semaphore(MAX_CONCURRENT_GENERATIONS)
125
+
126
+ api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
127
+
128
+ async def get_api_key(api_key: str = Depends(api_key_header)):
129
+ if API_KEY is None:
130
+ return
131
+ if api_key is None or api_key != API_KEY:
132
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or missing API Key")
133
+ return api_key
134
+
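+ # Request schema for the generation endpoint; it exposes most of the transformers generation parameters.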
135
+ class GenerateRequest(BaseModel):
136
+ input_text: str = Field(..., description="The input text from the user.", examples=["Hola, ¿cómo estás?"])
137
+ history: Optional[List[Dict[str, str]]] = Field(None, description="Conversation history.", examples=[[{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"}, {"role": "assistant", "content": "The capital of France is Paris."}]])
138
+ stream: bool = Field(True, description="Whether to stream the response.")
139
+ temperature: float = Field(1.0, ge=0.0, le=2.0, description="Controls the randomness.")
140
+ top_k: int = Field(50, ge=0, description="Top-k filtering.")
141
+ top_p: float = Field(1.0, ge=0.0, le=1.0, description="Top-p (nucleus) filtering.")
142
+ repetition_penalty: float = Field(1.0, ge=0.0, description="Repetition penalty.")
143
+ frequency_penalty: float = Field(0.0, ge=0.0, description="Frequency penalty.")
144
+ presence_penalty: float = Field(0.0, ge=0.0, description="Presence penalty.")
145
+ num_beams: int = Field(1, ge=1, description="Number of beams for beam search.")
146
+ length_penalty: float = Field(1.0, ge=0.0, description="Length penalty.")
147
+ no_repeat_ngram_size: int = Field(0, ge=0, description="No repeat ngram size.")
148
+ early_stopping: bool = Field(False, description="Early stopping for beam search.")
149
+ do_sample: bool = Field(True, description="Whether to use sampling.")
150
+ use_mirostat: bool = Field(False, description="Whether to use Mirostat sampling.")
151
+ mirostat_tau: float = Field(5.0, ge=0.0, description="Mirostat tau.")
152
+ mirostat_eta: float = Field(0.1, ge=0.0, description="Mirostat eta.")
153
+ max_new_tokens: int = Field(MAX_GENERATION_TOKENS, ge=1, description="Max new tokens.")
154
+ system_prompt: Optional[str] = Field(None, description="Override the default system prompt.")
155
+ seed: Optional[int] = Field(None, description="Random seed.")
156
+ stop_sequences: Optional[List[str]] = Field(None, description="List of stop strings.", examples=[[".", "\nUsuario:"]])
157
+ tokenize_only: bool = Field(False, description="If true, only tokenize input.")
158
+ strip_trailing_whitespace: bool = Field(False, description="Strip trailing whitespace.")
159
+ remove_incomplete_sentences: bool = Field(False, description="Remove incomplete last sentence.")
160
+ num_return_sequences: int = Field(1, ge=1, le=5, description="Number of sequences to return (non-streaming).")
161
+ bad_words_ids: Optional[List[List[int]]] = Field(None, description="List of bad word token ids.", examples=[[[32000], [32001]]])
162
+ forced_bos_token_id: Optional[int] = Field(None, description="Forced BOS token id.")
163
+ forced_eos_token_id: Optional[int] = Field(None, description="Forced EOS token id.")
164
+ renormalize_logits: Optional[bool] = Field(None, description="Renormalize logits.")
165
+ suppress_tokens: Optional[List[int]] = Field(None, description="Tokens to suppress.")
166
+ begin_suppress_tokens: Optional[List[int]] = Field(None, description="Tokens to suppress at beginning.")
167
+ end_suppress_tokens: Optional[List[int]] = Field(None, description="Tokens to suppress at end.")
168
+ encoder_no_repeat_ngram_size: int = Field(0, ge=0, description="Encoder no repeat ngram size.")
169
+ min_length: int = Field(0, ge=0, description="Minimum total length.")
170
+ max_length: Optional[int] = Field(None, description="Maximum total length.")
171
+ exponential_decay_length_penalty: Optional[Tuple[float, int, float]] = Field(None, description="Exponential decay length penalty.")
172
+ use_cache: bool = Field(True, description="Use cache.")
173
+ typical_p: float = Field(1.0, ge=0.0, le=1.0, description="Typical P sampling.")
174
+ epsilon_cutoff: float = Field(0.0, ge=0.0, description="Epsilon cutoff for LTS.")
175
+ eta_cutoff: float = Field(0.0, ge=0.0, description="Eta cutoff for LTS.")
176
+ temperature_cutoff: Optional[float] = Field(None, ge=0.0, description="Temperature cutoff.")
177
+ encoder_repetition_penalty: float = Field(1.0, ge=0.0, description="Encoder repetition penalty.")
178
+ max_time: Optional[float] = Field(None, ge=0.0, description="Maximum time in seconds.")
179
+ output_watermark: bool = Field(False, description="Output watermark.")
180
+ remove_input_from_output: bool = Field(False, description="Remove input from output.")
181
+ eos_token_id_override: Optional[int] = Field(None, description="Override EOS token id.")
182
+ pad_token_id_override: Optional[int] = Field(None, description="Override PAD token id.")
183
+ bos_token_id_override: Optional[int] = Field(None, description="Override BOS token id.")
184
+ repetition_penalty_range: Optional[int] = Field(None, ge=0, description="Repetition penalty range.")
185
+ diversity_penalty: float = Field(0.0, ge=0.0, description="Diversity penalty for diverse beam search.")
186
+ num_beam_groups: int = Field(1, ge=1, description="Number of beam groups for diverse beam search.")
187
+ return_dict_in_generate: bool = Field(False, description="Return dictionary from generate.")
188
+ output_attentions: bool = Field(False, description="Output attentions.")
189
+ output_hidden_states: bool = Field(False, description="Output hidden states.")
190
+ output_scores: bool = Field(False, description="Output scores.")
191
+ return_token_logprobs: bool = Field(False, description="Return token logprobs in stream.")
192
+ return_text_from_sequence: bool = Field(True, description="Decode generated sequence to text.")
193
+ length_normalization_factor: Optional[float] = Field(None, description="Length normalization factor for beam search.")
194
+ min_new_tokens: int = Field(0, ge=0, description="Minimum number of new tokens.")
195
+ do_normalize_logits: bool = Field(False, description="Normalize logits.")
196
+ return_generation_inputs: bool = Field(False, description="Return generation inputs.")
197
+ return_unused_generate_parameters: bool = Field(False, description="Return unused generate parameters.")
198
+ use_fast_tokenizer: bool = Field(True, description="Use fast tokenizer if available.")
199
+ model_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional model kwargs for generate.")
200
+ tokenizer_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional tokenizer kwargs for encode.")
201
+ return_only_text: bool = Field(False, description="If true, only return the generated text.")
202
+
203
+ @validator('stop_sequences')
204
+ def validate_stop_sequences(cls, v):
205
+ if v is not None:
206
+ if not all(isinstance(seq, str) for seq in v):
207
+ raise ValueError('Each stop sequence must be a string')
208
+ return v
209
+
210
+ @validator('bad_words_ids')
211
+ def validate_bad_words_ids(cls, v):
212
+ if v is not None:
213
+ if not all(isinstance(word_id_list, list) and all(isinstance(token_id, int) for token_id in word_id_list) for word_id_list in v):
214
+ raise ValueError('bad_words_ids must be a list of lists of integers')
215
+ return v
216
+
217
+ @validator('exponential_decay_length_penalty')
218
+ def validate_exponential_decay_length_penalty(cls, v):
219
+ if v is not None:
220
+ if not (isinstance(v, (list, tuple)) and len(v) == 3 and
221
+ isinstance(v[0], (int, float)) and v[0] > 0 and
222
+ isinstance(v[1], int) and v[1] >= 0 and
223
+ isinstance(v[2], (int, float))):
224
+ raise ValueError('exponential_decay_length_penalty must be a tuple/list of 3 numbers (decay_factor, start_index, threshold)')
225
+ return v
226
+
227
+ class TokenizeRequest(BaseModel):
228
+ text: Union[str, List[str]] = Field(..., description="Text or list of texts to tokenize.")
229
+ add_special_tokens: bool = Field(True, description="Whether to add special tokens.")
230
+ is_split_into_words: bool = Field(False, description="Whether the input text is pre-tokenized.")
231
+ return_token_type_ids: bool = Field(False, description="Whether to return token type IDs.")
232
+ padding: Union[bool, str] = Field(False, description="Enable padding.")
233
+ truncation: Union[bool, str] = Field(False, description="Enable truncation.")
234
+ max_length: Optional[int] = Field(None, ge=1, description="Maximum length for padding and truncation.")
235
+ return_tensors: Optional[str] = Field(None, description="The type of tensors to return.")
236
+ return_attention_mask: Optional[bool] = Field(None, description="Whether to return the attention mask.")
237
+ return_offsets_mapping: Optional[bool] = Field(None, description="Whether to return offsets mapping.")
238
+ return_length: Optional[bool] = Field(None, description="Whether to return the length.")
239
+ verbose: bool = Field(False, description="Verbose tokenizer output.")
240
+ tokenizer_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional tokenizer kwargs.")
241
+
242
+ class DecodeRequest(BaseModel):
243
+ token_ids: List[int] = Field(..., description="List of token IDs to decode.", examples=[[1, 2, 3]])
244
+ skip_special_tokens: bool = Field(True, description="Skip special tokens.")
245
+ clean_up_tokenization_spaces: bool = Field(True, description="Clean up spaces.")
246
+ decode_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional decode kwargs.")
247
+
248
+ class SystemPromptUpdateRequest(BaseModel):
249
+ system_prompt: str = Field(..., description="The new global system prompt.")
250
+
251
+ class ModelReloadRequest(BaseModel):
252
+ model_name: Optional[str] = Field(None, description="New model name.")
253
+ trust_remote_code: Optional[bool] = Field(None, description="Override trust_remote_code.")
254
+ enable_flash_attention_2: Optional[bool] = Field(None, description="Override enable_flash_attention_2.")
255
+ torch_dtype: Optional[str] = Field(None, description="Override torch_dtype.")
256
+ model_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional model kwargs for from_pretrained().")
257
+ tokenizer_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional tokenizer kwargs for from_pretrained().")
258
+
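+ # Build the prompt from the system prompt, the prior turns and the new user message, preferring the tokenizer's chat template when one is defined.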
259
+ def format_conversation(input_text: str, history: Optional[List[Dict[str, str]]], system_prompt: Optional[str]) -> str:
260
+ full_history: List[Dict[str, str]] = []
261
+ used_system_prompt = system_prompt if system_prompt is not None else SYSTEM_PROMPT
262
+ if not history or history[0].get("role") != "system" or history[0].get("content") != used_system_prompt:
263
+ full_history.append({"role": "system", "content": used_system_prompt})
264
+ if history:
265
+ full_history.extend(history)
266
+ if not full_history or full_history[-1].get("role") != "user" or full_history[-1].get("content") != input_text:
267
+ full_history.append({"role": "user", "content": input_text})
268
+
269
+ if global_tokenizer and hasattr(global_tokenizer, 'apply_chat_template') and global_tokenizer.chat_template:
270
+ try:
271
+ return global_tokenizer.apply_chat_template(full_history, tokenize=False, add_generation_prompt=True)
272
+ except Exception as e:
273
+ logging.error(f"Failed to apply chat template: {e}. Falling back to manual formatting.")
274
+ pass
275
+ formatted_text = ""
276
+ for i, message in enumerate(full_history):
277
+ if i == 0 and message["role"] == "system" and len(full_history) > 1 and full_history[1].get("role") == "system":
278
+ continue
279
+ if message["role"] == "system":
280
+ formatted_text += f"{message['content'].strip()}\n\n"
281
+ elif message["role"] == "user":
282
+ formatted_text += f"Usuario: {message['content'].strip()}\n"
283
+ elif message["role"] == "assistant":
284
+ formatted_text += f"Bot: {message['content'].strip()}\n"
285
+ if not formatted_text.endswith("Bot:"):
286
+ formatted_text += "Bot:"
287
+ return formatted_text.strip()
288
+
289
+ def truncate_encoded_ids(input_ids: torch.Tensor, max_length: int) -> torch.Tensor:
290
+ if input_ids.shape[-1] > max_length:
291
+ return input_ids[:, -max_length:]
292
+ return input_ids
293
+
294
+ def apply_seed(seed: Optional[int]):
295
+ if seed is not None:
296
+ torch.manual_seed(seed)
297
+ random.seed(seed)
298
+ if torch.cuda.is_available():
299
+ torch.cuda.manual_seed_all(seed)
300
+
301
+ def get_stopping_criteria(req: GenerateRequest, initial_ids: torch.Tensor, tokenizer: AutoTokenizer) -> StoppingCriteriaList:
302
+ criteria = StoppingCriteriaList()
303
+ max_len_from_req = None
304
+ if req.max_length is not None and req.max_length > 0:
305
+ max_len_from_req = req.max_length
306
+ elif req.max_new_tokens is not None and req.max_new_tokens > 0:
307
+ max_len_from_req = initial_ids.shape[-1] + req.max_new_tokens
308
+ else:
309
+ max_len_from_req = initial_ids.shape[-1] + MAX_GENERATION_TOKENS
310
+ if max_len_from_req is not None and max_len_from_req > 0:
311
+ criteria.append(MaxLengthCriteria(max_len_from_req))
312
+ # Note: MinLengthLogitsProcessor is a logits processor, not a StoppingCriteria, so it cannot be
313
+ # appended to a StoppingCriteriaList (calling the list would then fail at runtime); req.min_length
314
+ # is therefore not enforced through the stopping criteria.
315
+ if req.stop_sequences:
316
+ criteria.append(StopSequenceCriteria(req.stop_sequences, tokenizer))
317
+ return criteria
318
+
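+ # Single forward pass plus manual sampling (temperature, top-k, top-p, typical-p, epsilon/eta cutoffs); run in a worker thread by the streaming loop.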
319
+ def generate_next_token_sync(
320
+ input_ids,
321
+ past_key_values,
322
+ gen_cfg: GenerationConfig,
323
+ device: str
324
+ ) -> Tuple[torch.Tensor, Any, Optional[float], Optional[torch.Tensor], Any, Any]:
325
+ with torch.no_grad():
326
+ outputs = global_model(
327
+ input_ids, past_key_values=past_key_values,
328
+ use_cache=gen_cfg.use_cache, return_dict=True,
329
+ output_attentions=gen_cfg.output_attentions,
330
+ output_hidden_states=gen_cfg.output_hidden_states,
331
+ # output_scores is a generate() option, not a forward() argument, so it is not passed to the model here.
332
+ )
333
+ logits = outputs.logits[:, -1, :]
334
+ past = outputs.past_key_values
335
+ scores = None  # forward() outputs carry no .scores attribute; per-token log-probs are computed from the logits below
336
+ attentions = outputs.attentions if gen_cfg.output_attentions else None
337
+ hidden_states = outputs.hidden_states if gen_cfg.output_hidden_states else None
338
+ step_logits_for_criteria = logits.clone()
339
+ if gen_cfg.do_normalize_logits:
340
+ logits = F.log_softmax(logits, dim=-1)
341
+ if gen_cfg.do_sample:
342
+ if gen_cfg.use_mirostat_mode == 1 and hasattr(global_model, 'mirostat_sample_logits'):
343
+ token = global_model.mirostat_sample_logits(
344
+ logits=logits,
345
+ temperature=gen_cfg.temperature,
346
+ mirostat_tau=gen_cfg.mirostat_tau,
347
+ mirostat_eta=gen_cfg.mirostat_eta
348
+ ).unsqueeze(0).to(device)
349
+ else:
350
+ logits = logits / gen_cfg.temperature
351
+ if gen_cfg.temperature_cutoff is not None and gen_cfg.temperature_cutoff > 0:
352
+ logits = torch.where(logits < gen_cfg.temperature_cutoff, torch.tensor(-float('Inf')).to(logits.device), logits)
353
+ if gen_cfg.top_k:
354
+ topk_values, topk_indices = torch.topk(logits, min(gen_cfg.top_k, logits.size(-1)))
355
+ logits[logits < topk_values[:, -1]] = -float('Inf')
356
+ if gen_cfg.top_p < 1.0:
357
+ sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
358
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
359
+ sorted_indices_to_remove = cumulative_probs > gen_cfg.top_p
360
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
361
+ sorted_indices_to_remove[..., 0] = False
362
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
363
+ logits[:, indices_to_remove] = -float('Inf')
364
+ if gen_cfg.typical_p < 1.0:
365
+ probs = torch.softmax(logits, dim=-1)
366
+ entropy = torch.distributions.Categorical(probs).entropy()
367
+ probs_sorted, indices_sorted = torch.sort(probs, dim=-1, descending=True)
368
+ cumsum_probs_sorted = torch.cumsum(probs_sorted, dim=-1)
369
+ mask = cumsum_probs_sorted < gen_cfg.typical_p * entropy.exp()
370
+ indices_to_remove = indices_sorted[~mask]
371
+ logits[:, indices_to_remove] = -float('Inf')
372
+ if gen_cfg.epsilon_cutoff is not None and gen_cfg.epsilon_cutoff > 0:
373
+ probs = torch.softmax(logits, dim=-1)
374
+ mask = probs < gen_cfg.epsilon_cutoff
375
+ logits[mask] = -float('Inf')
376
+ if gen_cfg.eta_cutoff is not None and gen_cfg.eta_cutoff > 0:
377
+ probs = torch.softmax(logits, dim=-1)
378
+ mask = probs > gen_cfg.eta_cutoff
379
+ logits[~mask] = -float('Inf')
380
+ probs = torch.softmax(logits, dim=-1)
381
+ token = torch.multinomial(probs, 1)
382
+ else:
383
+ token = torch.argmax(logits, dim=-1, keepdim=True)
384
+ token_logprob = None
385
+ if gen_cfg.output_scores:
386
+ log_probs = F.log_softmax(step_logits_for_criteria, dim=-1)
387
+ if 0 <= token.squeeze().item() < log_probs.shape[-1]:
388
+ token_logprob = float(log_probs[:, token.squeeze()].item())
389
+ else:
390
+ token_logprob = None
391
+ return token, past, token_logprob, step_logits_for_criteria, attentions, hidden_states
392
+
393
+ def post_process_text(text: str, strip_trailing_whitespace: bool, remove_incomplete_sentences: bool) -> str:
394
+ if strip_trailing_whitespace:
395
+ text = text.rstrip()
396
+ if remove_incomplete_sentences:
397
+ for terminator in ['.', '!', '?', '\n']:
398
+ last_terminator = text.rfind(terminator)
399
+ if last_terminator != -1:
400
+ text = text[:last_terminator + 1]
401
+ break
402
+ return text
403
+
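+ # Token-by-token generation: yields newline-delimited JSON chunks (or raw text) and a final "done" payload with the post-processed text.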
404
+ async def stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, device: str) -> AsyncGenerator[Union[str, Tuple[Dict[str, Any], str]], None]:
405
+ past = None
406
+ generated_tokens_count = 0
407
+ eos_token_id = req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id")
408
+ pad_token_id = req.pad_token_id_override if req.pad_token_id_override is not None else global_tokens.get("pad_token_id", eos_token_id)
409
+ stop_token_ids = {eos_token_id} if eos_token_id is not None else set()
410
+ if pad_token_id is not None and pad_token_id != eos_token_id:
411
+ stop_token_ids.add(pad_token_id)
412
+
413
+ current_ids = initial_ids
414
+ start_time = time.time()
415
+ total_ids_list = initial_ids.tolist()[0]
416
+ finish_reason = "unknown"
417
+
418
+ stopping_criteria = get_stopping_criteria(req, initial_ids, global_tokenizer)
419
+
420
+ last_step_logits = None
421
+ accumulated_text_for_processing = ""
422
+
423
+ try:
424
+ while True:
425
+ if generated_tokens_count >= req.max_new_tokens:
426
+ finish_reason = "max_new_tokens"
427
+ break
428
+ if req.max_time is not None and (time.time() - start_time) > req.max_time:
429
+ finish_reason = "time"
430
+ break
431
+
432
+ input_ids_sync = current_ids if past is None else token
433
+
434
+ token, past, token_logprob, step_logits, attentions, hidden_states = await asyncio.to_thread(
435
+ generate_next_token_sync,
436
+ input_ids_sync,
437
+ past,
438
+ gen_cfg,
439
+ device
440
+ )
441
+ last_step_logits = step_logits
442
+
443
+ generated_token_id = token[0].item()
444
+ total_ids_list.append(generated_token_id)
445
+
446
+ text = global_tokenizer.decode([generated_token_id], skip_special_tokens=True)
447
+ accumulated_text_for_processing += text
448
+
449
+ if req.return_only_text:
450
+ yield text
451
+ else:
452
+ chunk_payload: Dict[str, Any] = {
453
+ "type": "token",
454
+ "text": text,
455
+ "token_id": generated_token_id,
456
+ "generated_tokens_count": generated_tokens_count + 1,
457
+ }
458
+ if req.return_token_logprobs and token_logprob is not None:
459
+ chunk_payload["logprob"] = token_logprob
460
+
461
+ yield json.dumps(chunk_payload) + "\n"
462
+
463
+ if generated_token_id in stop_token_ids:
464
+ finish_reason = "eos_token"
465
+ break
466
+
467
+ current_full_ids_tensor = torch.tensor([total_ids_list], device=device)
468
+ if stopping_criteria(current_full_ids_tensor, step_logits):
469
+ finish_reason = "stopping_criteria"
470
+ current_len = len(total_ids_list)
471
+ initial_len = initial_ids.shape[-1]
472
+
473
+ max_len_crit_met = any(isinstance(c, MaxLengthCriteria) for c in stopping_criteria) and \
474
+ ( (req.max_new_tokens is not None and current_len >= (initial_len + req.max_new_tokens)) or
475
+ (req.max_length is not None and current_len >= req.max_length) )
476
+ stop_seq_crit_met = any(isinstance(c, StopSequenceCriteria) for c in stopping_criteria) and req.stop_sequences and \
477
+ any(seq in global_tokenizer.decode(total_ids_list[initial_len:], skip_special_tokens=True) for seq in req.stop_sequences)
478
+
479
+ if max_len_crit_met:
480
+ if req.max_new_tokens is not None and current_len >= (initial_len + req.max_new_tokens):
481
+ finish_reason = "max_new_tokens"
482
+ elif req.max_length is not None and current_len >= req.max_length:
483
+ finish_reason = "max_length"
484
+
485
+ if stop_seq_crit_met:
486
+ finish_reason = "stop_sequence"
487
+
488
+
489
+ break
490
+
491
+
492
+ current_ids = token
493
+ generated_tokens_count += 1
494
+
495
+ final_text_raw = global_tokenizer.decode(total_ids_list[initial_ids.shape[-1]:], skip_special_tokens=True)
496
+ if req.stop_sequences and finish_reason == "stop_sequence":
497
+ for stop_seq in req.stop_sequences:
498
+ if stop_seq and stop_seq in final_text_raw:
499
+ final_text_raw = final_text_raw.split(stop_seq, 1)[0]
500
+ break
501
+
502
+ final_text_processed = post_process_text(final_text_raw, req.strip_trailing_whitespace, req.remove_incomplete_sentences)
503
+
504
+
505
+ if not req.return_only_text:
506
+ final_payload: Dict[str, Any] = {
507
+ "type": "done",
508
+ "total_prompt_tokens": initial_ids.shape[-1],
509
+ "total_generated_tokens": generated_tokens_count,
510
+ "total_sequence_tokens": len(total_ids_list),
511
+ "final_text": final_text_processed,
512
+ "finish_reason": finish_reason
513
+ }
514
+ yield json.dumps(final_payload) + "\n"
515
+
516
+
517
+ except Exception as e:
518
+ logging.exception("Streaming generation error:")
519
+ if req.return_only_text:
520
+ yield f"Error: {e}\n"
521
+ else:
522
+ error_payload = {"type": "error", "message": str(e)}
523
+ yield json.dumps(error_payload) + "\n"
524
+
525
+ finally:
526
+ await cleanup(device)
527
+
528
+
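+ # Non-streaming path: delegates to model.generate() and returns every requested sequence, with finish reasons, in one JSON payload.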
529
+ async def non_stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, device: str) -> Dict[str, Any]:
530
+ try:
531
+ logits_processor_list = LogitsProcessorList()
532
+
533
+ stopping_criteria_list = get_stopping_criteria(req, initial_ids, global_tokenizer)
534
+
535
+
536
+ with torch.no_grad():
537
+ out = global_model.generate(
538
+ input_ids=initial_ids,
539
+ generation_config=gen_cfg,
540
+ return_dict_in_generate=True,
541
+ output_scores=req.output_scores,
542
+ output_attentions=req.output_attentions,
543
+ output_hidden_states=req.output_hidden_states,
544
+ num_return_sequences=req.num_return_sequences,
545
+ bad_words_ids=req.bad_words_ids,
546
+ suppress_tokens=req.suppress_tokens,
547
+ begin_suppress_tokens=req.begin_suppress_tokens,
548
+ end_suppress_tokens=req.end_suppress_tokens,
549
+ logits_processor=logits_processor_list if logits_processor_list else None,
550
+ stopping_criteria=stopping_criteria_list if stopping_criteria_list else None,
551
+ )
552
+
553
+ generated_data = []
554
+ for i in range(req.num_return_sequences):
555
+ if i >= len(out.sequences):
556
+ break
557
+
558
+ sequence = out.sequences[i]
559
+ start_index = initial_ids.shape[-1]
560
+ generated_ids_tensor = sequence[start_index:]
561
+ full_sequence_ids = sequence.tolist()
562
+
563
+ text = global_tokenizer.decode(generated_ids_tensor, skip_special_tokens=True)
564
+
565
+ if req.stop_sequences:
566
+ for stop_seq in req.stop_sequences:
567
+ if stop_seq and stop_seq in text:
568
+ text = text.split(stop_seq, 1)[0]
569
+ break
570
+
571
+ text = post_process_text(text, req.strip_trailing_whitespace, req.remove_incomplete_sentences)
572
+
573
+ finish_reason = "length"
574
+ eos_token_id = req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id")
575
+ if len(generated_ids_tensor) > 0 and eos_token_id is not None and generated_ids_tensor[-1] == eos_token_id:
576
+ finish_reason = "eos_token"
577
+ elif len(generated_ids_tensor) >= gen_cfg.max_new_tokens:
578
+ finish_reason = "max_new_tokens"
579
+ elif req.max_length is not None and len(full_sequence_ids) >= req.max_length:
580
+ finish_reason = "max_length"
581
+ elif hasattr(out, 'max_time_exceeded') and out.max_time_exceeded:
582
+ finish_reason = "time"
583
+
584
+ if req.stop_sequences and finish_reason == "length":
585
+ decoded_full_output = global_tokenizer.decode(full_sequence_ids, skip_special_tokens=True)
586
+ if any(seq in decoded_full_output for seq in req.stop_sequences):
587
+ finish_reason = "stop_sequence"
588
+
589
+
590
+ item_data: Dict[str, Any] = {
591
+ "text": text if req.return_text_from_sequence else None,
592
+ "token_ids": generated_ids_tensor.tolist(),
593
+ "generated_tokens_count": len(generated_ids_tensor),
594
+ "finish_reason": finish_reason
595
+ }
596
+ if not req.remove_input_from_output:
597
+ item_data["full_sequence_token_ids"] = full_sequence_ids
598
+
599
+ if req.output_scores and hasattr(out, 'scores') and out.scores is not None:
600
+ item_data["scores"] = "Scores output needs custom handling (complex structure)."
601
+
602
+ if req.return_token_logprobs:
603
+ item_data["token_logprobs"] = "Token logprobs require parsing scores output which is complex for batched/beamed generation."
604
+
605
+ if req.output_attentions and hasattr(out, 'attentions') and out.attentions is not None:
606
+ item_data["attentions"] = "Attentions output needs custom handling (too large)."
607
+ if req.output_hidden_states and hasattr(out, 'hidden_states') and out.hidden_states is not None:
608
+ item_data["hidden_states"] = "Hidden states output needs custom handling (too large)."
609
+ if hasattr(out, 'watermark') and out.watermark is not None:
610
+ item_data["watermark"] = out.watermark[i] if isinstance(out.watermark, list) and len(out.watermark) > i else out.watermark
611
+
612
+
613
+ generated_data.append(item_data)
614
+
615
+
616
+ response_payload: Dict[str, Any] = {
617
+ "prompt_tokens": initial_ids.shape[-1],
618
+ "generated_sequences": generated_data,
619
+ }
620
+ if req.num_return_sequences == 1 and generated_data:
621
+ response_payload["total_tokens"] = response_payload["prompt_tokens"] + generated_data[0]["generated_tokens_count"]
622
+
623
+ if req.return_dict_in_generate:
624
+ raw_out_dict = {}
625
+ for key in out.keys():
626
+ if key not in ['sequences', 'scores', 'attentions', 'hidden_states', 'past_key_values', 'watermark', 'sequences_scores']:
627
+ value = out[key]
628
+ if isinstance(value, torch.Tensor):
629
+ raw_out_dict[key] = value.tolist()
630
+ else:
631
+ raw_out_dict[key] = value
632
+
633
+ response_payload["raw_generate_output"] = raw_out_dict
634
+
635
+ return response_payload
636
+
637
+ except Exception as e:
638
+ logging.exception("Non-streaming generation error:")
639
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Generation error: {e}")
640
+
641
+ async def cleanup(device: str):
642
+ if device == "cuda" and torch.cuda.is_available():
643
+ torch.cuda.empty_cache()
644
+ gc.collect()
645
+
646
+
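+ # Startup hook: loads the tokenizer and a deliberately shrunken model configuration (single layer, small hidden size) on CPU.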
647
+ @app.on_event("startup")
648
+ async def load_model():
649
+ global global_model, global_tokenizer, global_tokens, MODEL_NAME, TRUST_REMOTE_CODE, ENABLE_FLASH_ATTENTION_2, TORCH_DTYPE, TORCH_DTYPE_STR, TRUST_REMOTE_CODE_ENV
650
+
651
+ torch.set_num_threads(max(1, os.cpu_count() // 2))
652
+ torch.set_num_interop_threads(max(1, os.cpu_count() // 4))
653
+
654
+ torch.backends.cuda.preferred_linalg_backend = "fused" if torch.backends.cuda.is_built() else None
655
+ torch.backends.cudnn.benchmark = True if torch.cuda.is_available() else False
656
+
657
+ try:
658
+ TORCH_DTYPE = getattr(torch, TORCH_DTYPE_STR.lower(), torch.float32)
659
+ if TORCH_DTYPE != torch.float32:
660
+ logging.warning(f"Requested dtype {TORCH_DTYPE_STR} might not be fully performant on CPU. Using float32.")
661
+ TORCH_DTYPE = torch.float32
662
+ except AttributeError:
663
+ logging.warning(f"Invalid TORCH_DTYPE specified: {TORCH_DTYPE_STR}. Falling back to float32.")
664
+ TORCH_DTYPE = torch.float32
665
+
666
+ current_model_name = MODEL_NAME
667
+ current_trust_remote_code = TRUST_REMOTE_CODE_ENV or (current_model_name == DEFAULT_MODEL_NAME)
668
+ device = "cpu"
669
+
670
+ try:
671
+ logging.info(f"Loading config for model: {current_model_name}")
672
+ config = AutoConfig.from_pretrained(current_model_name, trust_remote_code=current_trust_remote_code)
673
+ original_config = copy.deepcopy(config)
674
+
675
+ logging.info(f"Modifying config for simplified model.")
676
+
677
+ if hasattr(config, 'num_hidden_layers'):
678
+ config.num_hidden_layers = 1
679
+ elif hasattr(config, 'num_layers'):
680
+ config.num_layers = 1
681
+
682
+ if hasattr(config, 'bos_token_id'):
683
+ config.bos_token_id = 1
684
+
685
+ if hasattr(config, 'do_sample'):
686
+ config.do_sample = None
687
+
688
+ if hasattr(config, 'eos_token_id'):
689
+ config.eos_token_id = 2
690
+
691
+ if hasattr(config, 'head_dim'):
692
+ config.head_dim = 96
693
+
694
+ if hasattr(config, 'hidden_size'):
695
+ config.hidden_size = 192
696
+
697
+ if hasattr(config, 'initializer_range'):
698
+ config.initializer_range = 0.02
699
+
700
+ if hasattr(config, 'intermediate_size'):
701
+ config.intermediate_size = 512
702
+
703
+ if hasattr(config, 'max_position_embeddings'):
704
+ config.max_position_embeddings = MAX_CONTEXT_TOKENS
705
+
706
+ if hasattr(config, 'n_positions'):
707
+ config.n_positions = MAX_CONTEXT_TOKENS
708
+
709
+ if hasattr(config, 'seq_len'):
710
+ config.seq_len = MAX_CONTEXT_TOKENS
711
+
712
+ if hasattr(config, 'ctx'):
713
+ config.ctx = MAX_CONTEXT_TOKENS
714
+
715
+ if hasattr(config, 'n_ctx'):
716
+ config.n_ctx = MAX_CONTEXT_TOKENS
717
+
718
+ if hasattr(config, 'max_seq_length'):
719
+ config.max_seq_length = MAX_CONTEXT_TOKENS
720
+
721
+ if hasattr(config, 'max_sequence_length'):
722
+ config.max_sequence_length = MAX_CONTEXT_TOKENS
723
+
724
+ if hasattr(config, 'max_length'):
725
+ config.max_length = MAX_CONTEXT_TOKENS
726
+
727
+ if hasattr(config, 'block_size'):
728
+ config.block_size = MAX_CONTEXT_TOKENS
729
+
730
+ if hasattr(config, 'use_cache'):
731
+ config.use_cache = False
732
+
733
+ if hasattr(config, 'gradient_checkpointing'):
734
+ config.gradient_checkpointing = True
735
+
736
+ if hasattr(config, 'torch_dtype'):
737
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
738
+ config.torch_dtype = 'bfloat16'
739
+ else:
740
+ config.torch_dtype = 'float16'
741
+
742
+ if hasattr(config, 'use_bfloat16'):
743
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
744
+ config.use_bfloat16 = True
745
+ else:
746
+ config.use_bfloat16 = False
747
+
748
+ if hasattr(config, 'attention_probs_dropout_prob'):
749
+ config.attention_probs_dropout_prob = 0.1
750
+
751
+ if hasattr(config, 'hidden_dropout_prob'):
752
+ config.hidden_dropout_prob = 0.1
753
+
754
+ if hasattr(config, 'layerdrop'):
755
+ config.layerdrop = 0.1
756
+
757
+ if hasattr(config, 'layer_norm_eps'):
758
+ config.layer_norm_eps = 1e-5
759
+
760
+ if hasattr(config, 'initializer_range'):
761
+ config.initializer_range = 0.02
762
+
763
+ if hasattr(config, 'rotary_pct'):
764
+ config.rotary_pct = 0.25
765
+
766
+ if hasattr(config, 'rotary_emb_base'):
767
+ config.rotary_emb_base = 10000
768
+
769
+ if hasattr(config, 'position_embedding_type'):
770
+ config.position_embedding_type = 'rotary'
771
+
772
+ if hasattr(config, 'activation_function'):
773
+ config.activation_function = 'gelu_new'
774
+
775
+ if hasattr(config, 'vocab_size'):
776
+ config.vocab_size = 32000
777
+
778
+ if hasattr(config, 'quantization_config'):
779
+ if torch.cuda.is_available():
780
+ config.quantization_config = {
781
+ 'load_in_8bit': True,
782
+ 'load_in_4bit': False,
783
+ 'bnb_4bit_compute_dtype':'float16',
784
+ 'bnb_4bit_use_double_quant':True,
785
+ 'bnb_4bit_quant_type':'nf4'
786
+ }
787
+ else:
788
+ logging.warning("Quantization config requested but CUDA not available. Skipping quantization config modification.")
789
+ config.quantization_config = {}
790
+
791
+ if hasattr(config, 'load_in_8bit'):
792
+ if torch.cuda.is_available():
793
+ config.load_in_8bit = True
794
+ else:
795
+ config.load_in_8bit = False
796
+
797
+ if hasattr(config, 'load_in_4bit'):
798
+ if torch.cuda.is_available():
799
+ config.load_in_4bit = False
800
+ else:
801
+ config.load_in_4bit = False
802
+
803
+ if hasattr(config, 'tie_word_embeddings'):
804
+ config.tie_word_embeddings = True
805
+
806
+ if hasattr(config, 'output_attentions'):
807
+ config.output_attentions = False
808
+
809
+ if hasattr(config, 'output_hidden_states'):
810
+ config.output_hidden_states = False
811
+
812
+ if hasattr(config, 'use_cache'):
813
+ config.use_cache = False
814
+
815
+ logging.info(f"Loading tokenizer for model: {current_model_name}")
816
+ tokenizer_kwargs = {"config": original_config, "trust_remote_code": current_trust_remote_code}
817
+ global_tokenizer = AutoTokenizer.from_pretrained(current_model_name, **tokenizer_kwargs)
818
+ logging.info("Tokenizer loaded.")
819
+
820
+ logging.info(f"Loading model: {current_model_name} with modified config and dtype {TORCH_DTYPE} onto {device}")
821
+
822
+ model_kwargs = {"config": config, "torch_dtype": TORCH_DTYPE, "trust_remote_code": current_trust_remote_code}
823
+
824
+ global_model = AutoModelForCausalLM.from_pretrained(current_model_name, **model_kwargs)
825
+ global_model.to(device)
826
+
827
+ try:
828
+ global_model = torch.compile(global_model, mode="max-autotune")
829
+ logging.info("Model compiled with torch.compile (max-autotune mode).")
830
+ except Exception as e:
831
+ logging.warning(f"Failed to compile model with torch.compile: {e}")
832
+ pass
833
+
834
+ global_model.eval()
835
+ logging.info("Model loaded successfully.")
836
+
837
+ global_tokens["eos_token_id"] = global_tokenizer.eos_token_id
838
+ global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
839
+ if global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is not None:
840
+ global_tokens["pad_token_id"] = global_tokens["eos_token_id"]
841
+ if global_model.config.pad_token_id is None:
842
+ global_model.config.pad_token_id = global_tokens["pad_token_id"]
843
+ elif global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is None:
844
+ logging.warning("Neither EOS nor PAD token is defined for this tokenizer/model.")
845
+ if global_model.config.pad_token_id is None and global_tokens.get("pad_token_id") is not None:
846
+ global_model.config.pad_token_id = global_tokens["pad_token_id"]
847
+
848
+ except Exception as e:
849
+ logging.exception("Failed to load model or tokenizer:")
850
+ global_model = None
851
+ global_tokenizer = None
852
+ global_tokens = {}
853
+
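+ # Minimal browser chat client served at "/"; it POSTs to /generate and consumes the newline-delimited JSON stream.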
854
+ html_code = """
855
+ <!DOCTYPE html>
856
+ <html lang="es">
857
+ <head>
858
+ <meta charset="UTF-8" />
859
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
860
+ <title>Chatbot Profesional</title>
861
+ <style>
862
+ body { font-family: Arial, sans-serif; margin: 20px; }
863
+ #chatbox { width: 100%; height: 400px; border: 1px solid #ccc; padding: 10px; overflow-y: scroll; margin-bottom: 10px; }
864
+ #user-input { width: calc(100% - 100px); padding: 8px; box-sizing: border-box;}
865
+ #send-btn { width: 90px; padding: 8px 0;}
866
+ #input-area { display: flex;}
867
+ </style>
868
+ </head>
869
+ <body>
870
+ <h1>Chatbot Profesional (POST API)</h1>
871
+ <div id="chatbox"></div>
872
+ <div id="input-area">
873
+ <input type="text" id="user-input" placeholder="Escribe tu mensaje aquí..." autocomplete="off"/>
874
+ <button id="send-btn">Enviar</button>
875
+ </div>
876
+ <script>
877
+ const chatbox = document.getElementById('chatbox');
878
+ const userInput = document.getElementById('user-input');
879
+ const sendBtn = document.getElementById('send-btn');
880
+
881
+ let conversationHistory = [];
882
+ const DEFAULT_SYSTEM_PROMPT = "Eres un asistente profesional y servicial.";
883
+ let currentSystemPrompt = DEFAULT_SYSTEM_PROMPT;
884
+ let botMessageElement = null;
885
+
886
+ function appendMessage(sender, text, isStreaming = false) {
887
+ let msg;
888
+ if (isStreaming && botMessageElement) {
889
+ botMessageElement.textContent += text;
890
+ } else {
891
+ msg = document.createElement('p');
892
+ msg.innerHTML = `<strong>${sender}:</strong> `;
893
+ const textNode = document.createTextNode(text);
894
+ msg.appendChild(textNode);
895
+ chatbox.appendChild(msg);
896
+ if (sender === 'Bot' && isStreaming) {
897
+ botMessageElement = textNode;
898
+ } else {
899
+ botMessageElement = null;
900
+ }
901
+ }
902
+ chatbox.scrollTop = chatbox.scrollHeight;
903
+ }
904
+
905
+ function updateHistory(role, content) {
906
+ conversationHistory.push({ "role": role, "content": content });
907
+ const maxHistorySize = 10;
908
+ if (conversationHistory.length > maxHistorySize * 2) {
909
+ conversationHistory = conversationHistory.slice(-(maxHistorySize * 2));
910
+ }
911
+ }
912
+
913
+ async function sendMessage() {
914
+ const text = userInput.value;
915
+ if (!text) {
916
+ return;
917
+ }
918
+ appendMessage('Usuario', text);
919
+ updateHistory("user", text);
920
+ userInput.value = '';
921
+ sendBtn.disabled = true;
922
+
923
+ botMessageElement = null;
924
+
925
+ const messagePayload = {
926
+ input_text: text,
927
+ history: conversationHistory,
928
+ system_prompt: currentSystemPrompt,
929
+ stream: true,
930
+ temperature: 1.0,
931
+ top_k: 50,
932
+ top_p: 1.0,
933
+ repetition_penalty: 1.0,
934
+ frequency_penalty: 0.0,
935
+ presence_penalty: 0.0,
936
+ num_beams: 1,
937
+ length_penalty: 1.0,
938
+ no_repeat_ngram_size: 0,
939
+ early_stopping: false,
940
+ do_sample: true,
941
+ use_mirostat: false,
942
+ mirostat_tau: 5.0,
943
+ mirostat_eta: 0.1,
944
+ max_new_tokens: 512,
945
+ num_return_sequences: 1,
946
+ return_token_logprobs: true
947
+ };
948
+
949
+ try {
950
+ const response = await fetch('/generate', {
951
+ method: 'POST',
952
+ headers: {
953
+ 'Content-Type': 'application/json',
954
+ // Add API Key header if needed
955
+ // 'X-API-Key': 'YOUR_API_KEY_HERE'
956
+ },
957
+ body: JSON.stringify(messagePayload),
958
+ });
959
+
960
+ if (!response.ok) {
961
+ const errorData = await response.json();
962
+ throw new Error(`API Error: ${response.status} ${response.statusText} - ${errorData.detail || errorData.error}`);
963
+ }
964
+
965
+ const reader = response.body.getReader();
966
+ const decoder = new TextDecoder();
967
+ let buffer = '';
968
+ let currentBotResponse = "";
969
+
970
+ while (true) {
971
+ const { value, done } = await reader.read();
972
+ if (done) break;
973
+
974
+ buffer += decoder.decode(value, { stream: true });
975
+
976
+ const lines = buffer.split('\n');
977
+ buffer = lines.pop();
978
+
979
+ for (const line of lines) {
980
+ if (line.trim() === '') continue;
981
+ try {
982
+ const data = JSON.parse(line);
983
+ if (data.type === 'token') {
984
+ currentBotResponse += data.text;
985
+ appendMessage('Bot', data.text, true);
986
+ console.log('Token:', data.token_id, 'Text:', data.text, 'Logprob:', data.logprob);
987
+ } else if (data.type === 'done') {
988
+ console.log('Generation done', data);
989
+ if (data.total_tokens !== undefined) {
990
+ appendMessage('System', `Generated ${data.total_tokens} tokens. Finish reason: ${data.finish_reason}`);
991
+ }
992
+ if (data.final_text !== undefined) {
993
+ updateHistory("assistant", data.final_text);
994
+ } else if (currentBotResponse) {
995
+ updateHistory("assistant", currentBotResponse);
996
+ }
997
+
998
+ } else if (data.type === 'error') {
999
+ appendMessage('Error', data.message);
1000
+ currentBotResponse = "";
1001
+ }
1002
+ } catch (e) {
1003
+ console.error('Failed to parse stream chunk:', e, line);
1004
+ appendMessage('Error', 'Failed to process stream.');
1005
+ currentBotResponse = "";
1006
+ reader.cancel();
1007
+ return;
1008
+ }
1009
+ }
1010
+ }
1011
+
1012
+ if (buffer.trim() !== '') {
1013
+ try {
1014
+ const data = JSON.parse(buffer);
1015
+ if (data.type === 'token') {
1016
+ currentBotResponse += data.text;
1017
+ appendMessage('Bot', data.text, true);
1018
+ console.log('Token:', data.token_id, 'Text:', data.text, 'Logprob:', data.logprob);
1019
+ } else if (data.type === 'done') {
1020
+ console.log('Generation done', data);
1021
+ if (data.total_tokens !== undefined) {
1022
+ appendMessage('System', `Generated ${data.total_tokens} tokens. Finish reason: ${data.finish_reason}`);
1023
+ }
1024
+ if (data.final_text !== undefined) {
1025
+ updateHistory("assistant", data.final_text);
1026
+ } else if (currentBotResponse) {
1027
+ updateHistory("assistant", currentBotResponse);
1028
+ }
1029
+ } else if (data.type === 'error') {
1030
+ appendMessage('Error', data.message);
1031
+ currentBotResponse = "";
1032
+ }
1033
+ } catch (e) {
1034
+ console.error('Failed to parse remaining buffer:', e, buffer);
1035
+ appendMessage('Error', 'Failed to process remaining stream data.');
1036
+ currentBotResponse = "";
1037
+ }
1038
+ }
1039
+
1040
+
1041
+ if (currentBotResponse && !botMessageElement) {
1042
+ updateHistory("assistant", currentBotResponse);
1043
+ }
1044
+ botMessageElement = null;
1045
+ currentBotResponse = "";
1046
+
1047
+
1048
+ } catch (error) {
1049
+ console.error('Send message error:', error);
1050
+ appendMessage('Error', error.message || 'An unknown error occurred.');
1051
+ botMessageElement = null;
1052
+ currentBotResponse = "";
1053
+ } finally {
1054
+ sendBtn.disabled = false;
1055
+ }
1056
+ }
1057
+
1058
+ sendBtn.onclick = sendMessage;
1059
+
1060
+ userInput.addEventListener('keypress', function(event) {
1061
+ if (event.key === 'Enter') {
1062
+ event.preventDefault();
1063
+ sendMessage();
1064
+ }
1065
+ });
1066
+
1067
+
1068
+ </script>
1069
+ </body>
1070
+ </html>
1071
+ """
1072
+
1073
+ @app.get("/", response_class=HTMLResponse, summary="Interactive HTML interface")
1074
+ async def root():
1075
+ return HTMLResponse(content=html_code)
1076
+
1077
+ async def check_health():
1078
+ model_loaded = global_model is not None
1079
+ tokenizer_loaded = global_tokenizer is not None
1080
+ status_data = {
1081
+ "model_loaded": model_loaded,
1082
+ "tokenizer_loaded": tokenizer_loaded,
1083
+ "status": "ok" if model_loaded and tokenizer_loaded else "loading model",
1084
+ "cuda_available": torch.cuda.is_available(),
1085
+ "cpu_cores": os.cpu_count(),
1086
+ "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS,
1087
+ "currently_running_generations": MAX_CONCURRENT_GENERATIONS - generation_semaphore._value,
1088
+ "available_slots": generation_semaphore._value,
1089
+ }
1090
+ if torch.cuda.is_available():
1091
+ device_count = torch.cuda.device_count()
1092
+ status_data["device_count"] = device_count
1093
+ status_data["devices"] = []
1094
+ for i in range(device_count):
1095
+ try:
1096
+ device_status = {
1097
+ "id": i,
1098
+ "name": torch.cuda.get_device_name(i),
1099
+ "total_memory_mib": round(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024), 2),
1100
+ "allocated_memory_mib": round(torch.cuda.memory_allocated(i) / (1024 * 1024), 2),
1101
+ "cached_memory_mib": round(torch.cuda.memory_reserved(i) / (1024 * 1024), 2),
1102
+ }
1103
+ status_data["devices"].append(device_status)
1104
+ except Exception as e:
1105
+ logging.error(f"Error getting GPU memory info for device {i}: {e}")
1106
+ status_data["devices"].append({"id": i, "error": str(e)})
1107
+ else:
1108
+ status_data["message"] = "CUDA not available. GPU resource info is not applicable."
1109
+ return status_data
1110
+
1111
+ async def get_config_data():
1112
+ torch_dtype_str_out = str(TORCH_DTYPE).split('.')[-1] if isinstance(TORCH_DTYPE, torch.dtype) else str(TORCH_DTYPE)
1113
+ return {
1114
+ "model_name": MODEL_NAME,
1115
+ "system_prompt_default": SYSTEM_PROMPT,
1116
+ "max_context_tokens": MAX_CONTEXT_TOKENS,
1117
+ "max_generation_tokens": MAX_GENERATION_TOKENS,
1118
+ "cuda_available": torch.cuda.is_available(),
1119
+ "model_loaded": global_model is not None,
1120
+ "tokenizer_loaded": global_tokenizer is not None,
1121
+ "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS,
1122
+ "trust_remote_code_startup_env": TRUST_REMOTE_CODE_ENV,
1123
+ "trust_remote_code_effective": TRUST_REMOTE_CODE,
1124
+ "enable_flash_attention_2": ENABLE_FLASH_ATTENTION_2,
1125
+ "torch_dtype": torch_dtype_str_out,
1126
+ "eos_token_id": global_tokens.get("eos_token_id"),
1127
+ "pad_token_id": global_tokens.get("pad_token_id"),
1128
+ "bos_token_id": global_tokenizer.bos_token_id if global_tokenizer else None,
1129
+ "api_key_required": API_KEY is not None
1130
+ }
1131
+
1132
+ async def get_model_info_data():
1133
+ if global_model is None:
1134
+ return {"model_name": MODEL_NAME, "is_loaded": False, "message": "Model is not loaded."}
1135
+ try:
1136
+ config_dict = global_model.config.to_dict()
1137
+ keys_to_remove = ['torch_dtype', '_attn_implementation', 'architectures', 'id2label', 'label2id']
1138
+ for key in keys_to_remove:
1139
+ config_dict.pop(key, None)
1140
+ return {
1141
+ "model_name": MODEL_NAME,
1142
+ "is_loaded": True,
1143
+ "device": str(global_model.device),
1144
+ "torch_dtype": str(global_model.dtype),
1145
+ "config": config_dict
1146
+ }
1147
+ except Exception as e:
1148
+ logging.exception("Error getting model info:")
1149
+ return {"model_name": MODEL_NAME, "is_loaded": True, "error": f"Error getting model info: {e}"}
1150
+
1151
+ async def internal_tokenize(text: Union[str, List[str]], add_special_tokens: bool = True, is_split_into_words: bool = False, return_token_type_ids: bool = False, padding: Union[bool, str] = False, truncation: Union[bool, str] = False, max_length: Optional[int] = None, return_tensors: Optional[str] = None, return_attention_mask: Optional[bool] = None, return_offsets_mapping: Optional[bool] = None, return_length: Optional[bool] = None, verbose: bool = False, tokenizer_kwargs: Optional[Dict[str, Any]] = None):
1152
+ if global_tokenizer is None:
1153
+ raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tokenizer is not loaded.")
1154
+ try:
1155
+ tokenizer_kwargs_final = tokenizer_kwargs or {}
1156
+ return_tensors_final = return_tensors if return_tensors is not None else None
1157
+ if return_tensors_final is None and (return_attention_mask or return_offsets_mapping or return_length):
1158
+ return_tensors_final = "pt"
1159
+ encoded = global_tokenizer(
1160
+ text,
1161
+ add_special_tokens=add_special_tokens,
1162
+ return_token_type_ids=return_token_type_ids,
1163
+ padding=padding,
1164
+ truncation=truncation,
1165
+ max_length=max_length,
1166
+ is_split_into_words=is_split_into_words,
1167
+ return_tensors=return_tensors_final,
1168
+ return_attention_mask=return_attention_mask,
1169
+ return_offsets_mapping=return_offsets_mapping,
1170
+ return_length=return_length,
1171
+ verbose=verbose,
1172
+ **tokenizer_kwargs_final
1173
+ )
1174
+ return encoded
1175
+ except Exception as e:
1176
+ logging.exception("Tokenization error:")
1177
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Tokenization error: {e}")
1178
+
1179
+ async def internal_decode(token_ids: List[int], skip_special_tokens: bool = True, clean_up_tokenization_spaces: bool = True, decode_kwargs: Optional[Dict[str, Any]] = None):
+ if global_tokenizer is None:
+ raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tokenizer is not loaded.")
+ try:
+ decode_kwargs_final = decode_kwargs or {}
+ text = global_tokenizer.decode(
+ token_ids,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **decode_kwargs_final
+ )
+ return {"text": text}
+ except Exception as e:
+ logging.exception("Decoding error:")
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Decoding error: {e}")
+
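+ # Round-trip sketch (illustrative; with the defaults above, plain Python lists are returned):
+ # enc = await internal_tokenize("hello", add_special_tokens=False)
+ # out = await internal_decode(enc["input_ids"])   # -> {"text": "hello"} (roughly, tokenizer-dependent)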
+ def update_global_system_prompt(new_prompt: str):
+ global SYSTEM_PROMPT
+ if new_prompt is not None:
+ SYSTEM_PROMPT = new_prompt.strip()
+ return {"status": "success", "message": "Global system prompt updated"}
+ else:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="System prompt cannot be null")
+
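+ # Example: update_global_system_prompt("You are a concise assistant.") returns
+ # {"status": "success", "message": "Global system prompt updated"}; passing None
+ # raises a 400 instead of silently clearing the prompt.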
+ async def internal_reload_model(req: ModelReloadRequest):
+ global global_model, global_tokenizer, global_tokens, MODEL_NAME, TRUST_REMOTE_CODE, ENABLE_FLASH_ATTENTION_2, TORCH_DTYPE, TORCH_DTYPE_STR, TRUST_REMOTE_CODE_ENV
+ new_model_name = req.model_name if req.model_name else MODEL_NAME
+ new_trust_remote_code = req.trust_remote_code if req.trust_remote_code is not None else (TRUST_REMOTE_CODE_ENV or (new_model_name == DEFAULT_MODEL_NAME))
+ new_enable_flash_attention_2 = req.enable_flash_attention_2 if req.enable_flash_attention_2 is not None else ENABLE_FLASH_ATTENTION_2
+ new_torch_dtype_str_req = req.torch_dtype if req.torch_dtype else TORCH_DTYPE_STR
+ try:
+ new_torch_dtype = getattr(torch, new_torch_dtype_str_req.lower())
+ if not isinstance(new_torch_dtype, torch.dtype):
+ raise AttributeError
+ if new_torch_dtype != torch.float32:
+ logging.warning(f"Requested dtype {new_torch_dtype_str_req} might not be fully performant on CPU. Using float32.")
+ new_torch_dtype = torch.float32
+ new_torch_dtype_str = str(new_torch_dtype).split('.')[-1]
+ except AttributeError:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid or unsupported torch_dtype: {new_torch_dtype_str_req}")
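+ # For example, torch_dtype="bfloat16" resolves via getattr(torch, "bfloat16") to
+ # torch.bfloat16 and is then coerced to float32 above because generation runs on CPU;
+ # a name that is not a torch dtype (e.g. "float99") hits the AttributeError branch
+ # and is rejected with a 400.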
+ device = "cpu"
1220
+ async def _reload():
1221
+ global global_model, global_tokenizer, global_tokens, MODEL_NAME, TRUST_REMOTE_CODE, ENABLE_FLASH_ATTENTION_2, TORCH_DTYPE, TORCH_DTYPE_STR
1222
+ logging.info(f"Attempting to load model: {new_model_name}")
1223
+ try:
1224
+ logging.info("Unloading current model...")
1225
+ await cleanup(device)
1226
+ if global_model is not None:
1227
+ del global_model
1228
+ global_model = None
1229
+ if global_tokenizer is not None:
1230
+ del global_tokenizer
1231
+ global_tokenizer = None
1232
+ global_tokens = {}
1233
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
1234
+ gc.collect()
1235
+ logging.info("Current model unloaded.")
1236
+ logging.info(f"Loading config for model: {new_model_name}")
1237
+ config = AutoConfig.from_pretrained(new_model_name, trust_remote_code=new_trust_remote_code)
1238
+ original_config = copy.deepcopy(config)
1239
+
1240
+ logging.info(f"Modifying config for simplified model.")
1241
+
1242
+ config_modifications = {
1243
+ 'num_hidden_layers': 1,
1244
+ 'num_layers': 1,
1245
+ 'bos_token_id': 1,
1246
+ 'do_sample': None,
1247
+ 'eos_token_id': 2,
1248
+ 'head_dim': 96,
1249
+ 'hidden_size': 192,
1250
+ 'initializer_range': 0.02,
1251
+ 'intermediate_size': 512,
1252
+ 'max_position_embeddings': MAX_CONTEXT_TOKENS,
1253
+ 'n_positions': MAX_CONTEXT_TOKENS,
1254
+ 'seq_len': MAX_CONTEXT_TOKENS,
1255
+ 'ctx': MAX_CONTEXT_TOKENS,
1256
+ 'n_ctx': MAX_CONTEXT_TOKENS,
1257
+ 'max_seq_length': MAX_CONTEXT_TOKENS,
1258
+ 'max_sequence_length': MAX_CONTEXT_TOKENS,
1259
+ 'max_length': MAX_CONTEXT_TOKENS,
1260
+ 'block_size': MAX_CONTEXT_TOKENS,
1261
+ 'use_cache': False,
1262
+ 'gradient_checkpointing': True,
1263
+ 'attention_probs_dropout_prob': 0.1,
1264
+ 'hidden_dropout_prob': 0.1,
1265
+ 'layerdrop': 0.1,
1266
+ 'layer_norm_eps': 1e-5,
1267
+ 'rotary_pct': 0.25,
1268
+ 'rotary_emb_base': 10000,
1269
+ 'position_embedding_type': 'rotary',
1270
+ 'activation_function': 'gelu_new',
1271
+ 'vocab_size': 32000,
1272
+ 'tie_word_embeddings': True,
1273
+ 'output_attentions': False,
1274
+ 'output_hidden_states': False,
1275
+ }
1276
+
1277
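+ # Note: the loop below only overrides attributes the loaded config already defines, so a
+ # GPT-2-style config picks up n_ctx=MAX_CONTEXT_TOKENS while a Llama-style config picks up
+ # max_position_embeddings instead; missing size/layer keys are only logged as warnings.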
+ for attr, new_val in config_modifications.items():
+ if hasattr(config, attr):
+ if attr == 'torch_dtype':
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+ setattr(config, attr, torch.bfloat16)
+ else:
+ setattr(config, attr, torch.float16)
+ elif attr == 'use_bfloat16':
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
+ setattr(config, attr, True)
+ else:
+ setattr(config, attr, False)
+ elif attr == 'quantization_config':
+ if torch.cuda.is_available():
+ setattr(config, attr, new_val)
+ else:
+ logging.warning(f"Quantization config requested for '{attr}' but CUDA not available. Skipping modification.")
+ else:
+ setattr(config, attr, new_val)
+ elif attr in ['num_hidden_layers', 'num_layers', 'max_position_embeddings', 'n_positions', 'seq_len', 'ctx', 'n_ctx', 'max_seq_length', 'max_sequence_length', 'max_length', 'block_size']:
+ logging.warning(f"Could not find a standard parameter '{attr}' in config for {new_model_name}. Max context/layer logic might not be fully effective.")
+
+
+ logging.info(f"Loading tokenizer for model: {new_model_name}")
1301
+ tokenizer_kwargs = {"config": original_config, "trust_remote_code": new_trust_remote_code}
1302
+ if req.tokenizer_kwargs:
1303
+ tokenizer_kwargs.update(req.tokenizer_kwargs)
1304
+ tokenizer = AutoTokenizer.from_pretrained(new_model_name, **tokenizer_kwargs)
1305
+ logging.info("Tokenizer loaded.")
1306
+
1307
+ logging.info(f"Loading model: {new_model_name} with modified config and dtype {new_torch_dtype_str} onto {device}")
1308
+ model_kwargs = {"config": config, "torch_dtype": new_torch_dtype, "trust_remote_code": new_trust_remote_code}
1309
+ model = AutoModelForCausalLM.from_pretrained(new_model_name, **model_kwargs)
1310
+ model.to(device)
1311
+
1312
+ try:
1313
+ model = torch.compile(model, mode="max-autotune")
1314
+ logging.info("New model compiled with torch.compile (max-autotune mode).")
1315
+ except Exception as e:
1316
+ logging.warning(f"Failed to compile new model with torch.compile: {e}")
1317
+ pass
1318
+ model.eval()
1319
+ logging.info("New model loaded successfully.")
1320
+ global_model = model
1321
+ global_tokenizer = tokenizer
1322
+ global_tokens["eos_token_id"] = global_tokenizer.eos_token_id
1323
+ global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
1324
+ if global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is not None:
1325
+ global_tokens["pad_token_id"] = global_tokens["eos_token_id"]
1326
+ if global_model.config.pad_token_id is None:
1327
+ global_model.config.pad_token_id = global_tokens["pad_token_id"]
1328
+ elif global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is None:
1329
+ logging.warning("Neither EOS nor PAD token defined for new model.")
1330
+ if global_model.config.pad_token_id is None and global_tokens.get("pad_token_id") is not None:
1331
+ global_model.config.pad_token_id = global_tokens["pad_token_id"]
1332
+ MODEL_NAME = new_model_name
1333
+ TRUST_REMOTE_CODE = new_trust_remote_code
1334
+ ENABLE_FLASH_ATTENTION_2 = new_enable_flash_attention_2
1335
+ TORCH_DTYPE = new_torch_dtype
1336
+ TORCH_DTYPE_STR = new_torch_dtype_str
1337
+ if hasattr(global_tokenizer, 'use_fast'):
1338
+ pass
1339
+ logging.info(f"Model successfully reloaded to: {MODEL_NAME}")
1340
+ logging.info({"status": "success", "message": f"Model {new_model_name} loaded successfully."})
1341
+ except Exception as e:
1342
+ logging.exception(f"Failed to load model {new_model_name}:")
1343
+ global_model = None
1344
+ global_tokenizer = None
1345
+ global_tokens = {}
1346
+ logging.error({"status": "error", "message": f"Failed to load model {new_model_name}: {e}. Model is now unloaded."})
1347
+ asyncio.create_task(_reload())
1348
+ return {"status": "info", "message": f"Attempting to load model {new_model_name} in background. Check logs for status."}
1349
+
1350
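+ # Illustrative reload payload (field names follow ModelReloadRequest as used above; the
+ # model id and values are examples only):
+ # {
+ #     "model_name": "gpt2",
+ #     "torch_dtype": "float32",
+ #     "trust_remote_code": false,
+ #     "tokenizer_kwargs": {"use_fast": true}
+ # }
+ # The call returns immediately with a "status": "info" message; the actual load happens in
+ # the background task scheduled above.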
+ async def internal_unload_model():
+ global global_model, global_tokenizer, global_tokens
+ device = "cpu"
+ logging.info("Attempting to unload model.")
+ try:
+ await cleanup(device)
+ if global_model is not None:
+ del global_model
+ global_model = None
+ if global_tokenizer is not None:
+ del global_tokenizer
+ global_tokenizer = None
+ global_tokens = {}
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+ logging.info("Model unloaded successfully.")
+ return {"status": "success", "message": "Model unloaded successfully."}
+ except Exception as e:
+ logging.exception("Failed to unload model:")
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to unload model: {e}")
+
+
+ @app.post("/generate", summary="Generate text", dependencies=[Depends(get_api_key)])
1373
+ async def generate_endpoint(req: GenerateRequest):
1374
+ if global_model is None or global_tokenizer is None:
1375
+ raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Model is not loaded. It may still be loading or failed to load.")
1376
+ device = "cpu"
1377
+ apply_seed(req.seed)
1378
+ try:
1379
+ initial_prompt_text = format_conversation(req.input_text, req.history, req.system_prompt)
1380
+ except Exception as e:
1381
+ logging.exception("Error formatting conversation:")
1382
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Error formatting conversation: {e}")
1383
+ try:
1384
+ tokenizer_encoding_kwargs = req.tokenizer_kwargs or {}
1385
+
1386
+ encoded = global_tokenizer(initial_prompt_text, return_tensors="pt", add_special_tokens=False, **tokenizer_encoding_kwargs).to(device)
1387
+ initial_ids_before_trunc = encoded.input_ids
1388
+ initial_prompt_tokens_count_before_trunc = initial_ids_before_trunc.shape[-1]
1389
+
1390
+ ids = truncate_encoded_ids(initial_ids_before_trunc, MAX_CONTEXT_TOKENS)
1391
+ current_prompt_tokens_count = ids.shape[-1]
1392
+
1393
+ except Exception as e:
1394
+ logging.exception("Tokenizer error during encoding:")
1395
+ await cleanup(device)
1396
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Tokenizer encoding error: {e}")
1397
+ if req.tokenize_only:
1398
+ await cleanup(device)
1399
+ return JSONResponse({
1400
+ "prompt_tokens_count": initial_prompt_tokens_count_before_trunc,
1401
+ "max_context_tokens": MAX_CONTEXT_TOKENS,
1402
+ "truncated": initial_prompt_tokens_count_before_trunc > MAX_CONTEXT_TOKENS,
1403
+ "input_text_processed": initial_prompt_text,
1404
+ "input_ids_truncated": ids.tolist()[0]
1405
+ })
1406
+ total_capacity = MAX_CONTEXT_TOKENS + MAX_GENERATION_TOKENS
1407
+ total_requested_seq_len = current_prompt_tokens_count + req.max_new_tokens
1408
+ if not req.stream and total_requested_seq_len > total_capacity:
1409
+ await cleanup(device)
1410
+ raise HTTPException(
1411
+ status_code=status.HTTP_400_BAD_REQUEST,
1412
+ detail=f"Requested sequence length ({total_requested_seq_len} tokens = {current_prompt_tokens_count} prompt + {req.max_new_tokens} new) exceeds model capacity ({total_capacity} tokens) and non-streaming is requested. Consider enabling streaming or reducing max_new_tokens."
1413
+ )
1414
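+ # Worked example of the capacity check above (numbers illustrative, not the configured
+ # values): with MAX_CONTEXT_TOKENS=1024 and MAX_GENERATION_TOKENS=512 the capacity is
+ # 1536 tokens, so a 900-token prompt with max_new_tokens=700 (total 1600) is rejected
+ # when "stream" is false, while the same request with "stream": true is allowed through.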
+ async with generation_semaphore:
+ try:
+ gen_cfg = GenerationConfig(
+ temperature=req.temperature,
+ top_k=req.top_k,
+ top_p=req.top_p,
+ repetition_penalty=req.repetition_penalty,
+ frequency_penalty=req.frequency_penalty,
+ presence_penalty=req.presence_penalty,
+ num_beams=req.num_beams if not req.stream else 1,
+ length_penalty=req.length_penalty,
+ no_repeat_ngram_size=req.no_repeat_ngram_size,
+ early_stopping=req.early_stopping,
+ do_sample=req.do_sample,
+ use_mirostat_mode=1 if req.use_mirostat else 0,
+ mirostat_tau=req.mirostat_tau,
+ mirostat_eta=req.mirostat_eta,
+ max_new_tokens=req.max_new_tokens,
+ eos_token_id=req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id"),
+ pad_token_id=req.pad_token_id_override if req.pad_token_id_override is not None else global_tokens.get("pad_token_id"),
+ bos_token_id=req.bos_token_id_override if req.bos_token_id_override is not None else global_tokenizer.bos_token_id,
+ num_return_sequences=req.num_return_sequences if not req.stream else 1,
+ bad_words_ids=req.bad_words_ids,
+ forced_bos_token_id=req.forced_bos_token_id,
+ forced_eos_token_id=req.forced_eos_token_id,
+ renormalize_logits=req.renormalize_logits,
+ suppress_tokens=req.suppress_tokens,
+ begin_suppress_tokens=req.begin_suppress_tokens,
+ end_suppress_tokens=req.end_suppress_tokens,
+ encoder_no_repeat_ngram_size=req.encoder_no_repeat_ngram_size,
+ min_length=req.min_length,
+ max_length=req.max_length,
+ exponential_decay_length_penalty=req.exponential_decay_length_penalty,
+ use_cache=req.use_cache,
+ typical_p=req.typical_p,
+ epsilon_cutoff=req.epsilon_cutoff,
+ eta_cutoff=req.eta_cutoff,
+ temperature_cutoff=req.temperature_cutoff,
+ encoder_repetition_penalty=req.encoder_repetition_penalty,
+ max_time=req.max_time,
+ output_watermark=req.output_watermark,
+ diversity_penalty=req.diversity_penalty,
+ num_beam_groups=req.num_beam_groups if not req.stream else 1,
+ length_normalization_factor=req.length_normalization_factor,
+ min_new_tokens=req.min_new_tokens,
+ do_normalize_logits=req.do_normalize_logits,
+ output_scores=req.output_scores,
+ output_attentions=req.output_attentions,
+ output_hidden_states=req.output_hidden_states,
+ )
+ if req.stream:
+ gen_cfg.use_cache = True
+ gen_cfg.num_beams = 1
+ gen_cfg.num_return_sequences = 1
+ gen_cfg.num_beam_groups = 1
+ return StreamingResponse(stream_generation_logic(req, ids, gen_cfg, device), media_type="text/plain" if req.return_only_text else "application/json")
+ else:
+ response_payload = await non_stream_generation_logic(req, ids, gen_cfg, device)
+ if req.return_only_text:
+ texts = [seq["text"] for seq in response_payload.get("generated_sequences", []) if seq.get("text") is not None]
+ if req.num_return_sequences == 1 and texts:
+ return PlainTextResponse(texts[0])
+ else:
+ return JSONResponse(texts)
+ else:
+ return JSONResponse(response_payload)
+ except Exception as e:
+ logging.exception("Generation error:")
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Generation error: {e}")
+ finally:
+ await cleanup(device)
+
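+ # Illustrative /generate request body (only fields referenced above; values are examples):
+ # {
+ #     "input_text": "Explain beam search in one sentence.",
+ #     "system_prompt": "You are a helpful assistant.",
+ #     "history": [],
+ #     "max_new_tokens": 64,
+ #     "temperature": 0.7,
+ #     "do_sample": true,
+ #     "stream": false,
+ #     "return_only_text": true,
+ #     "seed": 42
+ # }
+ # With return_only_text=true and num_return_sequences=1 the endpoint answers with plain text;
+ # otherwise it returns the JSON payload built by non_stream_generation_logic.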
+ if __name__ == "__main__":
+ uvicorn.run(
+ app, host="0.0.0.0", port=7860,
+ log_level="critical",
+ access_log=False
+ )