kouki321 committed on
Commit
2a5a9e0
·
verified ·
1 Parent(s): e682f3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -125
app.py CHANGED
@@ -6,34 +6,20 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
6
  from transformers.cache_utils import DynamicCache , StaticCache
7
  from pydantic import BaseModel
8
  from typing import Optional
9
- import uvicorn
10
  import tempfile
11
  from time import time
12
  from fastapi.responses import RedirectResponse
13
 
14
-
15
  # Add necessary serialization safety
16
  torch.serialization.add_safe_globals([DynamicCache])
17
  torch.serialization.add_safe_globals([set])
18
- #These lines allow PyTorch to serialize and deserialize these objects without raising errors,
19
- # #ensuring compatibility and functionality during cache saving/loading.
20
 
21
- # Minimal generate function for token-by-token generation
22
- def generate(model,
23
- input_ids,
24
- past_key_values,
25
- max_new_tokens=50):
26
- """
27
- This function performs token-by-token text generation using a pre-trained language model.
28
- Purpose: To generate new text based on input tokens, without loading the full context repeatedly
29
- Process: It takes a model, input IDs, and cached key-values, then generates new tokens one by one up to the specified maximum
30
- Performance: Uses the cached key-values for efficiency and returns only the newly generated tokens
31
- """
32
  device = model.model.embed_tokens.weight.device
33
- origin_len = input_ids.shape[-1]#Stores the length of the input sequence (number of tokens) before text generation begins./return only the newly
34
- input_ids = input_ids.to(device)#same device as the model.
35
- output_ids = input_ids.clone()#will be updated during the generation process to include newly generated tokens.
36
- next_token = input_ids#the token that will process in the next iteration.
37
  with torch.no_grad():
38
  for _ in range(max_new_tokens):
39
  out = model(
@@ -41,28 +27,19 @@ def generate(model,
41
  past_key_values=past_key_values,
42
  use_cache=True
43
  )
44
- logits = out.logits[:, -1, :]#Extracts the logits for the last token
45
- token = torch.argmax(logits, dim=-1, keepdim=True)#highest predicted probability as the next token.
46
- output_ids = torch.cat([output_ids, token], dim=-1)#add the newly generated token
47
  past_key_values = out.past_key_values
48
  next_token = token.to(device)
49
  if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
50
  break
51
- return output_ids[:, origin_len:] # Return just the newly generated part
52
-
53
  def get_kv_cache(model, tokenizer, prompt):
54
- """
55
- This function creates a key-value cache for a given prompt.
56
- Purpose: To pre-compute and store the model's internal representations (key-value states) for a prompt
57
- Process: Encodes the prompt, runs it through the model, and captures the resulting cache
58
- Returns: The cache object and the original prompt length for future reference
59
- """
60
- # Encode prompt
61
  device = model.model.embed_tokens.weight.device
62
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
63
- cache = DynamicCache() # it grows as text is generated
64
-
65
- # Run the model to populate the KV cache:
66
  with torch.no_grad():
67
  _ = model(
68
  input_ids=input_ids,
@@ -72,110 +49,74 @@ def get_kv_cache(model, tokenizer, prompt):
72
  return cache, input_ids.shape[-1]
73
 
74
  def clean_up(cache, origin_len):
75
- # Make a deep copy of the cache first
76
  new_cache = DynamicCache()
77
  for i in range(len(cache.key_cache)):
78
  new_cache.key_cache.append(cache.key_cache[i].clone())
79
  new_cache.value_cache.append(cache.value_cache[i].clone())
80
-
81
- # Remove any tokens appended to the original knowledge
82
  for i in range(len(new_cache.key_cache)):
83
  new_cache.key_cache[i] = new_cache.key_cache[i][:, :, :origin_len, :]
84
  new_cache.value_cache[i] = new_cache.value_cache[i][:, :, :origin_len, :]
85
  return new_cache
 
86
  os.environ["TRANSFORMERS_OFFLINE"] = "1"
87
  os.environ["HF_HUB_OFFLINE"] = "1"
88
 
89
- # Path to your local model
90
-
91
- # Initialize model and tokenizer
92
  def load_model_and_tokenizer():
93
- model_path = "./model"
94
-
95
- # Load tokenizer and model from disk (without trust_remote_code)
96
  tokenizer = AutoTokenizer.from_pretrained(model_path)
97
  if torch.cuda.is_available():
98
- # Load model on GPU if CUDA is available
99
  model = AutoModelForCausalLM.from_pretrained(
100
  model_path,
101
  torch_dtype=torch.float16,
102
- device_map="auto" # Automatically map model layers to GPU
103
  )
104
  else:
105
- # Load model on CPU if no GPU is available
106
  model = AutoModelForCausalLM.from_pretrained(
107
  model_path,
108
- torch_dtype=torch.float32, # Use float32 for compatibility with CPU
109
- low_cpu_mem_usage=True # Reduce memory usage on CPU
110
  )
111
  return model, tokenizer
112
 
113
- # Create FastAPI app
114
  app = FastAPI(title="DeepSeek QA with KV Cache API")
115
-
116
- # Global variables to store the cache, origin length, and model/tokenizer
117
  cache_store = {}
118
-
119
- # Initialize model and tokenizer at startup
120
  model, tokenizer = load_model_and_tokenizer()
121
 
122
  class QueryRequest(BaseModel):
123
  query: str
124
  max_new_tokens: Optional[int] = 150
 
125
  def clean_response(response_text):
126
- """
127
- Clean up model response by removing redundant tags, repetitions, and formatting issues.
128
- """
129
- # First, try to extract just the answer content between tags if they exist
130
  import re
131
-
132
- # Try to extract content between assistant tags if present
133
  assistant_pattern = re.compile(r'<\|assistant\|>\s*(.*?)(?:<\/\|assistant\|>|<\|user\|>|<\|system\|>)', re.DOTALL)
134
  matches = assistant_pattern.findall(response_text)
135
-
136
  if matches:
137
- # Return the first meaningful assistant response
138
  for match in matches:
139
  cleaned = match.strip()
140
  if cleaned and not cleaned.startswith("<|") and len(cleaned) > 5:
141
  return cleaned
142
-
143
- # If no proper match found, try more aggressive cleaning
144
- # Remove all tag markers completely
145
  cleaned = re.sub(r'<\|.*?\|>', '', response_text)
146
  cleaned = re.sub(r'<\/\|.*?\|>', '', cleaned)
147
-
148
- # Remove duplicate lines (common in generated responses)
149
  lines = cleaned.strip().split('\n')
150
  unique_lines = []
151
  for line in lines:
152
  line = line.strip()
153
  if line and line not in unique_lines:
154
  unique_lines.append(line)
155
-
156
  result = '\n'.join(unique_lines)
157
-
158
- # Final cleanup - remove any trailing system/user markers
159
  result = re.sub(r'<\/?\|.*?\|>\s*$', '', result)
160
-
161
  return result.strip()
 
162
  @app.post("/upload-document_to_create_KV_cache")
163
  async def upload_document(file: UploadFile = File(...)):
164
- """Upload a document and create KV cache for it"""
165
  t1 = time()
166
-
167
- # Save the uploaded file temporarily
168
  with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
169
  temp_file_path = temp_file.name
170
  content = await file.read()
171
  temp_file.write(content)
172
-
173
  try:
174
- # Read the document
175
  with open(temp_file_path, "r", encoding="utf-8") as f:
176
  doc_text = f.read()
177
-
178
- # Create system prompt with document context
179
  system_prompt = f"""
180
  <|system|>
181
  Answer concisely and precisely, You are an assistant who provides concise factual answers.
@@ -184,148 +125,104 @@ async def upload_document(file: UploadFile = File(...)):
184
  {doc_text}
185
  Question:
186
  """.strip()
187
-
188
- # Create KV cache
189
  cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
190
-
191
- # Generate a unique ID for this document/cache
192
  cache_id = f"cache_{int(time())}"
193
-
194
- # Store the cache and origin_len
195
  cache_store[cache_id] = {
196
  "cache": cache,
197
  "origin_len": origin_len,
198
  "doc_preview": doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
199
  }
200
-
201
- # Clean up the temporary file
202
  os.unlink(temp_file_path)
203
-
204
  t2 = time()
205
-
206
  return {
207
  "cache_id": cache_id,
208
  "message": "Document uploaded and cache created successfully",
209
  "doc_preview": cache_store[cache_id]["doc_preview"],
210
  "time_taken": f"{t2 - t1:.4f} seconds"
211
  }
212
-
213
  except Exception as e:
214
- # Clean up the temporary file in case of error
215
  if os.path.exists(temp_file_path):
216
  os.unlink(temp_file_path)
217
  raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
218
 
219
  @app.post("/generate_answer_from_cache/{cache_id}")
220
  async def generate_answer(cache_id: str, request: QueryRequest):
221
- """Generate an answer to a question based on the uploaded document"""
222
  t1 = time()
223
-
224
- # Check if the document/cache exists
225
  if cache_id not in cache_store:
226
  raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
227
-
228
  try:
229
- # Get a clean copy of the cache
230
  current_cache = clean_up(
231
- cache_store[cache_id]["cache"],
232
  cache_store[cache_id]["origin_len"]
233
  )
234
-
235
- # Prepare input with just the query
236
  full_prompt = f"""
237
  <|user|>
238
  Question: {request.query}
239
  <|assistant|>
240
  """.strip()
241
-
242
  input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
243
-
244
- # Generate response
245
  output_ids = generate(model, input_ids, current_cache, max_new_tokens=request.max_new_tokens)
246
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
247
- rep = clean_response(response)
248
  t2 = time()
249
-
250
  return {
251
  "query": request.query,
252
  "answer": rep,
253
  "time_taken": f"{t2 - t1:.4f} seconds"
254
  }
255
-
256
  except Exception as e:
257
  raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")
258
 
259
  @app.post("/save_cache/{cache_id}")
260
  async def save_cache(cache_id: str):
261
- """Save the cache for a document"""
262
  if cache_id not in cache_store:
263
  raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
264
-
265
  try:
266
- # Clean up the cache and save it
267
  cleaned_cache = clean_up(
268
- cache_store[cache_id]["cache"],
269
  cache_store[cache_id]["origin_len"]
270
  )
271
-
272
  cache_path = f"{cache_id}_cache.pth"
273
  torch.save(cleaned_cache, cache_path)
274
-
275
  return {
276
  "message": f"Cache saved successfully as {cache_path}",
277
  "cache_path": cache_path
278
  }
279
-
280
  except Exception as e:
281
  raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}")
282
 
283
  @app.post("/load_cache")
284
  async def load_cache(file: UploadFile = File(...)):
285
- """Load a previously saved cache"""
286
  with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as temp_file:
287
  temp_file_path = temp_file.name
288
  content = await file.read()
289
  temp_file.write(content)
290
-
291
  try:
292
- # Load the cache
293
  loaded_cache = torch.load(temp_file_path)
294
-
295
- # Generate a unique ID for this cache
296
  cache_id = f"loaded_cache_{int(time())}"
297
-
298
- # Store the cache (we don't have the original document text)
299
  cache_store[cache_id] = {
300
  "cache": loaded_cache,
301
  "origin_len": loaded_cache.key_cache[0].shape[-2],
302
  "doc_preview": "Loaded from cache file"
303
  }
304
-
305
- # Clean up the temporary file
306
  os.unlink(temp_file_path)
307
-
308
  return {
309
  "cache_id": cache_id,
310
  "message": "Cache loaded successfully"
311
  }
312
-
313
  except Exception as e:
314
- # Clean up the temporary file in case of error
315
  if os.path.exists(temp_file_path):
316
  os.unlink(temp_file_path)
317
  raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}")
318
 
319
  @app.get("/list_of_caches")
320
  async def list_documents():
321
- """List all uploaded documents/caches"""
322
  documents = {}
323
  for cache_id in cache_store:
324
  documents[cache_id] = {
325
  "doc_preview": cache_store[cache_id]["doc_preview"],
326
  "origin_len": cache_store[cache_id]["origin_len"]
327
  }
328
-
329
  return {"documents": documents}
330
 
331
  @app.get("/", include_in_schema=False)
@@ -333,4 +230,4 @@ async def root():
333
  return RedirectResponse(url="/docs")
334
 
335
  if __name__ == "__main__":
336
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
6
  from transformers.cache_utils import DynamicCache , StaticCache
7
  from pydantic import BaseModel
8
  from typing import Optional
 
9
  import tempfile
10
  from time import time
11
  from fastapi.responses import RedirectResponse
12
 
 
13
# Allow-list the types torch.load may deserialize under weights_only safety,
# so saved DynamicCache objects round-trip without errors.
torch.serialization.add_safe_globals([DynamicCache, set])
 
 
16
 
17
+ def generate(model, input_ids, past_key_values, max_new_tokens=50):
 
 
 
 
 
 
 
 
 
 
18
  device = model.model.embed_tokens.weight.device
19
+ origin_len = input_ids.shape[-1]
20
+ input_ids = input_ids.to(device)
21
+ output_ids = input_ids.clone()
22
+ next_token = input_ids
23
  with torch.no_grad():
24
  for _ in range(max_new_tokens):
25
  out = model(
 
27
  past_key_values=past_key_values,
28
  use_cache=True
29
  )
30
+ logits = out.logits[:, -1, :]
31
+ token = torch.argmax(logits, dim=-1, keepdim=True)
32
+ output_ids = torch.cat([output_ids, token], dim=-1)
33
  past_key_values = out.past_key_values
34
  next_token = token.to(device)
35
  if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
36
  break
37
+ return output_ids[:, origin_len:]
38
+
39
  def get_kv_cache(model, tokenizer, prompt):
 
 
 
 
 
 
 
40
  device = model.model.embed_tokens.weight.device
41
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
42
+ cache = DynamicCache()
 
 
43
  with torch.no_grad():
44
  _ = model(
45
  input_ids=input_ids,
 
49
  return cache, input_ids.shape[-1]
50
 
51
def clean_up(cache, origin_len):
    """Return a deep copy of `cache` truncated to the first `origin_len` tokens.

    Restores a stored KV cache to its prompt-only state before reuse, so
    tokens appended by an earlier generation do not leak into the next answer.

    Args:
        cache: DynamicCache whose per-layer key/value tensors are indexed as
            [:, :, seq, :] (sequence length on axis 2, as the slicing below shows).
        origin_len: number of leading prompt tokens to keep.

    Returns:
        A new DynamicCache; the input cache is left untouched.
    """
    new_cache = DynamicCache()
    for i in range(len(cache.key_cache)):
        # Slice BEFORE cloning so we never copy the tokens being discarded
        # (the original cloned the full tensor and then sliced the clone).
        new_cache.key_cache.append(cache.key_cache[i][:, :, :origin_len, :].clone())
        new_cache.value_cache.append(cache.value_cache[i][:, :, :origin_len, :].clone())
    return new_cache
60
+
61
# Force the Hugging Face libraries to use only local files (no hub downloads).
for _offline_flag in ("TRANSFORMERS_OFFLINE", "HF_HUB_OFFLINE"):
    os.environ[_offline_flag] = "1"
63
 
 
 
 
64
def load_model_and_tokenizer():
    """Load the causal LM and tokenizer from MODEL_PATH (default ./model).

    Chooses fp16 with automatic device mapping when CUDA is available,
    otherwise fp32 with reduced host-memory usage for CPU-only deployments.

    Returns:
        (model, tokenizer) tuple shared by all request handlers.
    """
    model_path = os.environ.get("MODEL_PATH", "./model")  # allow override via Docker env
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if torch.cuda.is_available():
        load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
    else:
        load_kwargs = {"torch_dtype": torch.float32, "low_cpu_mem_usage": True}
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    return model, tokenizer
80
 
 
81
# FastAPI application exposing the document-QA-with-KV-cache endpoints.
app = FastAPI(title="DeepSeek QA with KV Cache API")

# In-memory registry: cache_id -> {"cache", "origin_len", "doc_preview"}.
# Lives only for the process lifetime; nothing is persisted automatically.
cache_store = {}

# Load the model and tokenizer once at import time; shared by every request.
model, tokenizer = load_model_and_tokenizer()
84
 
85
class QueryRequest(BaseModel):
    """Request body for POST /generate_answer_from_cache/{cache_id}."""

    # The user's question, answered against the cached document context.
    query: str
    # Upper bound on the number of tokens generated for the answer.
    max_new_tokens: Optional[int] = 150
88
+
89
def clean_response(response_text):
    """Clean a raw model completion into a plain answer string.

    Strategy:
      1. If ``<|assistant|>`` spans (terminated by a closing or next-role tag)
         exist, return the first one that looks like real content (non-empty,
         not itself a tag, longer than 5 characters).
      2. Otherwise strip all ``<|...|>`` / ``</|...|>`` role markers, drop
         duplicate lines (keeping first occurrence), and trim any marker left
         dangling at the end.

    Args:
        response_text: decoded text returned by the model.

    Returns:
        The cleaned answer text (possibly empty).
    """
    import re

    # Prefer an explicitly assistant-tagged span when one is present.
    assistant_pattern = re.compile(r'<\|assistant\|>\s*(.*?)(?:<\/\|assistant\|>|<\|user\|>|<\|system\|>)', re.DOTALL)
    matches = assistant_pattern.findall(response_text)
    if matches:
        for match in matches:
            cleaned = match.strip()
            if cleaned and not cleaned.startswith("<|") and len(cleaned) > 5:
                return cleaned

    # Fallback: aggressively remove every role-tag marker.
    cleaned = re.sub(r'<\|.*?\|>', '', response_text)
    cleaned = re.sub(r'<\/\|.*?\|>', '', cleaned)

    # De-duplicate lines preserving first-seen order. A seen-set makes this
    # O(n) instead of the original O(n^2) list-membership test.
    seen = set()
    unique_lines = []
    for line in cleaned.strip().split('\n'):
        line = line.strip()
        if line and line not in seen:
            seen.add(line)
            unique_lines.append(line)

    result = '\n'.join(unique_lines)
    # Remove any trailing role marker that survived the substitutions.
    result = re.sub(r'<\/?\|.*?\|>\s*$', '', result)
    return result.strip()
109
+
110
  @app.post("/upload-document_to_create_KV_cache")
111
  async def upload_document(file: UploadFile = File(...)):
 
112
  t1 = time()
 
 
113
  with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
114
  temp_file_path = temp_file.name
115
  content = await file.read()
116
  temp_file.write(content)
 
117
  try:
 
118
  with open(temp_file_path, "r", encoding="utf-8") as f:
119
  doc_text = f.read()
 
 
120
  system_prompt = f"""
121
  <|system|>
122
  Answer concisely and precisely, You are an assistant who provides concise factual answers.
 
125
  {doc_text}
126
  Question:
127
  """.strip()
 
 
128
  cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
 
 
129
  cache_id = f"cache_{int(time())}"
 
 
130
  cache_store[cache_id] = {
131
  "cache": cache,
132
  "origin_len": origin_len,
133
  "doc_preview": doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
134
  }
 
 
135
  os.unlink(temp_file_path)
 
136
  t2 = time()
 
137
  return {
138
  "cache_id": cache_id,
139
  "message": "Document uploaded and cache created successfully",
140
  "doc_preview": cache_store[cache_id]["doc_preview"],
141
  "time_taken": f"{t2 - t1:.4f} seconds"
142
  }
 
143
  except Exception as e:
 
144
  if os.path.exists(temp_file_path):
145
  os.unlink(temp_file_path)
146
  raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
147
 
148
@app.post("/generate_answer_from_cache/{cache_id}")
async def generate_answer(cache_id: str, request: QueryRequest):
    """Answer ``request.query`` using the KV cache stored under ``cache_id``."""
    started = time()
    # Unknown cache id -> 404 before doing any model work.
    if cache_id not in cache_store:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
    try:
        entry = cache_store[cache_id]
        # Work on a truncated copy so the stored prompt cache stays pristine.
        current_cache = clean_up(entry["cache"], entry["origin_len"])
        full_prompt = f"""
<|user|>
Question: {request.query}
<|assistant|>
""".strip()
        input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
        output_ids = generate(model, input_ids, current_cache, max_new_tokens=request.max_new_tokens)
        decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        answer_text = clean_response(decoded)
        elapsed = time() - started
        return {
            "query": request.query,
            "answer": answer_text,
            "time_taken": f"{elapsed:.4f} seconds",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")
175
 
176
@app.post("/save_cache/{cache_id}")
async def save_cache(cache_id: str):
    """Persist the cleaned KV cache for ``cache_id`` to a .pth file on disk."""
    if cache_id not in cache_store:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
    try:
        entry = cache_store[cache_id]
        # Truncate back to the original prompt length before persisting.
        cleaned_cache = clean_up(entry["cache"], entry["origin_len"])
        cache_path = f"{cache_id}_cache.pth"
        torch.save(cleaned_cache, cache_path)
        return {
            "message": f"Cache saved successfully as {cache_path}",
            "cache_path": cache_path,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}")
193
 
194
@app.post("/load_cache")
async def load_cache(file: UploadFile = File(...)):
    """Restore a previously saved KV cache from an uploaded .pth file."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as temp_file:
        temp_file_path = temp_file.name
        temp_file.write(await file.read())
    try:
        # NOTE(review): torch.load unpickles arbitrary objects — only accept
        # cache files from trusted sources.
        loaded_cache = torch.load(temp_file_path)
        cache_id = f"loaded_cache_{int(time())}"
        cache_store[cache_id] = {
            "cache": loaded_cache,
            # Sequence length read from axis -2 of the first layer's key tensor
            # (matches the [:, :, seq, :] slicing used in clean_up).
            "origin_len": loaded_cache.key_cache[0].shape[-2],
            "doc_preview": "Loaded from cache file",
        }
        os.unlink(temp_file_path)
        return {
            "cache_id": cache_id,
            "message": "Cache loaded successfully",
        }
    except Exception as e:
        # Best-effort removal of the temp file on failure, then report 500.
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}")
217
 
218
@app.get("/list_of_caches")
async def list_documents():
    """Summarize every cache currently held in memory."""
    documents = {
        cache_id: {
            "doc_preview": entry["doc_preview"],
            "origin_len": entry["origin_len"],
        }
        for cache_id, entry in cache_store.items()
    }
    return {"documents": documents}
227
 
228
@app.get("/", include_in_schema=False)
async def root():
    """Redirect the bare root URL to the auto-generated Swagger docs page."""
    # NOTE(review): the `async def root():` line is elided by the diff
    # rendering; reconstructed from the hunk header — confirm against the file.
    return RedirectResponse(url="/docs")
231
 
232
if __name__ == "__main__":
    # Fix: this commit removed the top-level `import uvicorn` but the entry
    # point still calls it, so `python app.py` would raise NameError.
    # Import locally inside the guard (matches the lazy-import style and
    # avoids the dependency when the module is imported by an ASGI server).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)