kouki321 commited on
Commit
93da4bd
·
verified ·
1 Parent(s): 4283e53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +328 -26
app.py CHANGED
@@ -1,34 +1,336 @@
1
  import os
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BertForMaskedLM
3
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
8
- model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Model identifier
11
- model_id = "sshleifer/tiny-gpt2"
12
- #"google/flan-t5-small"
13
- #"unsloth/mistral-7b-v0.2-bnb-4bit"
14
- #deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
15
-
16
- #model_id = "remi/bertabs-finetuned-extractive-abstractive-summarization"
17
- import shutil
18
- total, used, free = shutil.disk_usage("/")
19
- print("Total: %.2f GB" % (total / (2**30)))
20
- print("Used: %.2f GB" % (used / (2**30)))
21
- print("Free: %.2f GB" % (free / (2**30)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- tokenizer = AutoTokenizer.from_pretrained(model_id)
24
- model = AutoModelForCausalLM.from_pretrained(model_id)
 
 
 
 
 
 
 
 
 
25
 
26
- def generate(text: str) -> str:
27
- inputs = tokenizer(text, return_tensors="pt")
28
- outputs = model.generate(**inputs, max_new_tokens=50)
29
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
30
 
31
  if __name__ == "__main__":
32
- prompt = "The capital of France is"
33
- result = generate(prompt)
34
- print(result)
 
1
  import os
2
+ import torch
3
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Body
4
+ from fastapi.responses import JSONResponse
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ from transformers.cache_utils import DynamicCache , StaticCache
7
+ from pydantic import BaseModel
8
+ from typing import Optional
9
+ import uvicorn
10
+ import tempfile
11
+ from time import time
12
+
13
+
14
# Allow-list the types embedded in saved cache files so torch.load can
# deserialize them in weights-only mode without raising.
torch.serialization.add_safe_globals([DynamicCache, set])
19
+
20
# Minimal greedy, token-by-token decoder that reuses a pre-filled KV cache.
def generate(model,
             input_ids,
             past_key_values,
             max_new_tokens=50):
    """Greedily decode up to max_new_tokens continuation tokens.

    The supplied past_key_values (KV cache) lets the model skip re-encoding
    the cached context on every step; only the newly generated tokens are
    returned, not the prompt.
    """
    # Keep all tensors on the same device as the embedding matrix.
    device = model.model.embed_tokens.weight.device
    prompt_len = input_ids.shape[-1]  # remember where the prompt ends
    input_ids = input_ids.to(device)
    generated = input_ids.clone()     # prompt + everything decoded so far
    step_input = input_ids            # tokens fed on the next forward pass
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(input_ids=step_input,
                        past_key_values=past_key_values,
                        use_cache=True)
            # Greedy pick: highest-scoring token at the final position.
            next_tok = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_tok], dim=-1)
            past_key_values = out.past_key_values
            step_input = next_tok.to(device)
            eos = model.config.eos_token_id
            if eos is not None and next_tok.item() == eos:
                break
    # Strip the prompt; hand back only the continuation.
    return generated[:, prompt_len:]
51
+
52
def get_kv_cache(model, tokenizer, prompt):
    """Pre-compute the model's KV states for `prompt`.

    Returns the populated cache plus the prompt's token count, so callers
    can later trim anything appended beyond the original context.
    """
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    # DynamicCache grows as tokens are processed.
    cache = DynamicCache()
    # A single forward pass over the prompt is enough to fill the cache.
    with torch.no_grad():
        model(input_ids=input_ids,
              past_key_values=cache,
              use_cache=True)
    return cache, input_ids.shape[-1]
72
+
73
def clean_up(cache, origin_len):
    """Return a deep copy of `cache` truncated to the first `origin_len` tokens.

    The stored cache is left untouched so one pre-computed document cache can
    be reused across many queries. Slicing *before* cloning copies only the
    tokens we keep, instead of cloning the full tensors and then discarding
    the tail (the original implementation cloned first, wasting memory).
    """
    new_cache = DynamicCache()
    for key, value in zip(cache.key_cache, cache.value_cache):
        # Layout assumed (batch, heads, seq_len, head_dim) — crop seq_len.
        new_cache.key_cache.append(key[:, :, :origin_len, :].clone())
        new_cache.value_cache.append(value[:, :, :origin_len, :].clone())
    return new_cache
85
# Force transformers / huggingface_hub to use only locally cached files —
# the model checkpoint is expected to already exist on disk.
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_HUB_OFFLINE"] = "1"
91
def load_model_and_tokenizer():
    """Load the local DeepSeek checkpoint; GPU/fp16 when CUDA is available."""
    model_path = "./deepseek"

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if torch.cuda.is_available():
        # fp16 plus device_map="auto" spreads layers across available GPUs.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    else:
        # CPU fallback: fp32 for kernel compatibility, low-memory loading.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
    return model, tokenizer
111
+
112
# FastAPI application plus process-wide state: the in-memory cache registry
# and the model/tokenizer pair, loaded once at import time.
app = FastAPI(title="DeepSeek QA with KV Cache API")

# cache_id -> {"cache": DynamicCache, "origin_len": int, "doc_preview": str}
cache_store = {}

model, tokenizer = load_model_and_tokenizer()
120
+
121
class QueryRequest(BaseModel):
    """Request payload for /generate_answer_from_cache/{cache_id}."""
    # The user's question, answered against the cached document context.
    query: str
    # Cap on how many tokens the answer may contain.
    max_new_tokens: Optional[int] = 150
124
def clean_response(response_text):
    """Clean a raw model completion for display.

    Tries, in order:
      1. extracting the first substantive ``<|assistant|>`` section, else
      2. stripping every chat-role tag, de-duplicating repeated lines
         (a common greedy-decoding artifact), and trimming any trailing
         role marker.
    """
    import re

    # 1) Prefer content between assistant tags when present.
    assistant_pattern = re.compile(
        r'<\|assistant\|>\s*(.*?)(?:<\/\|assistant\|>|<\|user\|>|<\|system\|>)',
        re.DOTALL,
    )
    for match in assistant_pattern.findall(response_text):
        cleaned = match.strip()
        # Skip empty, tag-leading, or trivially short matches.
        if cleaned and not cleaned.startswith("<|") and len(cleaned) > 5:
            return cleaned

    # 2) Aggressive fallback: remove all <|...|> and </|...|> markers.
    cleaned = re.sub(r'<\|.*?\|>', '', response_text)
    cleaned = re.sub(r'<\/\|.*?\|>', '', cleaned)

    # Keep the first occurrence of each non-empty line. A seen-set makes
    # this O(n); the original `line not in unique_lines` scan was O(n^2).
    unique_lines = []
    seen = set()
    for line in cleaned.strip().split('\n'):
        line = line.strip()
        if line and line not in seen:
            seen.add(line)
            unique_lines.append(line)
    result = '\n'.join(unique_lines)

    # Drop any role marker left dangling at the end of the text.
    return re.sub(r'<\/?\|.*?\|>\s*$', '', result).strip()
161
@app.post("/upload-document_to_create_KV_cache")
async def upload_document(file: UploadFile = File(...)):
    """Upload a document and create KV cache for it"""
    t1 = time()

    # Persist the upload to a temp file so it can be re-read as UTF-8 text.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
        temp_file_path = temp_file.name
        temp_file.write(await file.read())

    try:
        with open(temp_file_path, "r", encoding="utf-8") as f:
            doc_text = f.read()

        # The document is baked into a system prompt; the KV states computed
        # for this prompt are what gets cached and reused per question.
        system_prompt = f"""
<|system|>
Answer concisely and precisely, You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()

        cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)

        # Timestamp-based ID keys this cache in the in-memory registry.
        cache_id = f"cache_{int(time())}"
        cache_store[cache_id] = {
            "cache": cache,
            "origin_len": origin_len,
            "doc_preview": doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
        }

        os.unlink(temp_file_path)
        t2 = time()

        return {
            "cache_id": cache_id,
            "message": "Document uploaded and cache created successfully",
            "doc_preview": cache_store[cache_id]["doc_preview"],
            "time_taken": f"{t2 - t1:.4f} seconds"
        }

    except Exception as e:
        # Never leave the temp file behind on failure.
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
217
 
218
@app.post("/generate_answer_from_cache/{cache_id}")
async def generate_answer(cache_id: str, request: QueryRequest):
    """Generate an answer to a question based on the uploaded document"""
    t1 = time()

    if cache_id not in cache_store:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")

    try:
        # Work on a trimmed deep copy so the stored document cache stays
        # pristine across requests.
        entry = cache_store[cache_id]
        current_cache = clean_up(entry["cache"], entry["origin_len"])

        # Only the question is tokenized here — the document context already
        # lives inside the KV cache.
        full_prompt = f"""
<|user|>
Question: {request.query}
<|assistant|>
""".strip()
        input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids

        output_ids = generate(model, input_ids, current_cache, max_new_tokens=request.max_new_tokens)
        answer = clean_response(tokenizer.decode(output_ids[0], skip_special_tokens=True))
        t2 = time()

        return {
            "query": request.query,
            "answer": answer,
            "time_taken": f"{t2 - t1:.4f} seconds"
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")
257
 
258
@app.post("/save_cache/{cache_id}")
async def save_cache(cache_id: str):
    """Save the cache for a document"""
    if cache_id not in cache_store:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")

    try:
        # Persist a trimmed deep copy so the file holds only the document's
        # original context tokens.
        entry = cache_store[cache_id]
        cleaned_cache = clean_up(entry["cache"], entry["origin_len"])

        cache_path = f"{cache_id}_cache.pth"
        torch.save(cleaned_cache, cache_path)

        return {
            "message": f"Cache saved successfully as {cache_path}",
            "cache_path": cache_path
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}")
281
 
282
@app.post("/load_cache")
async def load_cache(file: UploadFile = File(...)):
    """Load a previously saved cache"""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as temp_file:
        temp_file_path = temp_file.name
        temp_file.write(await file.read())

    try:
        # NOTE(review): torch.load deserializes untrusted upload content.
        # Safety relies on the add_safe_globals allow-list registered at
        # import time — confirm weights-only loading is in effect for the
        # installed torch version.
        loaded_cache = torch.load(temp_file_path)

        cache_id = f"loaded_cache_{int(time())}"

        # The source document is unknown here; infer the context length
        # from the cached key tensor's sequence axis.
        cache_store[cache_id] = {
            "cache": loaded_cache,
            "origin_len": loaded_cache.key_cache[0].shape[-2],
            "doc_preview": "Loaded from cache file"
        }

        os.unlink(temp_file_path)

        return {
            "cache_id": cache_id,
            "message": "Cache loaded successfully"
        }

    except Exception as e:
        # Never leave the temp file behind on failure.
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}")
317
 
318
@app.get("/list_of_caches")
async def list_documents():
    """List all uploaded documents/caches"""
    # Expose only the preview and context length, never the tensors.
    return {
        "documents": {
            cache_id: {
                "doc_preview": entry["doc_preview"],
                "origin_len": entry["origin_len"]
            }
            for cache_id, entry in cache_store.items()
        }
    }
329
 
330
@app.get("/")
async def root():
    """Health-check / landing endpoint."""
    return {"message": "DeepSeek QA with KV Cache API is running"}
 
333
 
334
if __name__ == "__main__":
    # Serve on all interfaces, port 7860 (Hugging Face Spaces convention).
    uvicorn.run(app, host="0.0.0.0", port=7860)