import os

# HF_HOME must be set before huggingface_hub/transformers are imported,
# otherwise the default cache path has already been resolved.
os.environ["HF_HOME"] = "/app/hf_cache"

import re
import tempfile
from time import time
from typing import Optional

import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache


def get_embed_device(model):
    """Return the device holding the model's input embeddings.

    Checks the common HF layouts (plain decoder, PEFT-wrapped,
    encoder/decoder) and falls back to the first parameter's device.
    """
    if hasattr(model, "model") and hasattr(model.model, "embed_tokens"):
        return model.model.embed_tokens.weight.device
    elif hasattr(model, "base_model") and hasattr(model.base_model, "model") and hasattr(model.base_model.model, "embed_tokens"):
        return model.base_model.model.embed_tokens.weight.device
    elif hasattr(model, "decoder") and hasattr(model.decoder, "embed_tokens"):
        return model.decoder.embed_tokens.weight.device
    elif hasattr(model, "embed_tokens"):
        return model.embed_tokens.weight.device
    else:
        return next(model.parameters()).device


# torch >= 2.6 defaults torch.load(..., weights_only=True); allow-list the
# types needed to unpickle a saved DynamicCache (see /load_cache below).
torch.serialization.add_safe_globals([DynamicCache, set])


def generate(model, input_ids, past_key_values, max_new_tokens=50):
    """Greedy decoding that continues from an already-prefilled KV cache."""
    device = get_embed_device(model)
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True,
            )
            logits = out.logits[:, -1, :]
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)
            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break
    # Return only the newly generated tokens, not the echoed prompt.
    return output_ids[:, origin_len:]


def get_kv_cache(model, tokenizer, prompt):
    """Run one forward pass over `prompt`; return (filled cache, prompt length)."""
    device = get_embed_device(model)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()
    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True,
        )
    return cache, input_ids.shape[-1]


def clean_up(cache, origin_len):
    """Return a copy of `cache` cropped back to the original prompt length.

    Each question appends entries to the cache, so a fresh cropped copy is
    made per query to keep the stored document cache pristine.
    """
    new_cache = DynamicCache()
    for i in range(len(cache.key_cache)):
        # Slice the sequence dimension back to the prompt, then clone so
        # the stored cache is never mutated.
        new_cache.key_cache.append(cache.key_cache[i][:, :, :origin_len, :].clone())
        new_cache.value_cache.append(cache.value_cache[i][:, :, :origin_len, :].clone())
    return new_cache
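

# The three helpers above compose as follows -- a minimal sketch of the flow
# the endpoints below implement (`doc_text` and `question` are illustrative
# placeholders, not names defined in this module):
#
#   cache, origin_len = get_kv_cache(model, tokenizer, doc_text)  # prefill once
#   fresh = clean_up(cache, origin_len)                           # per-question copy
#   ids = tokenizer(question, return_tensors="pt").input_ids
#   new_ids = generate(model, ids, fresh, max_new_tokens=150)     # decode from cache
#   answer = tokenizer.decode(new_ids[0], skip_special_tokens=True)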


def clean_response(response_text):
    """Extract the assistant's answer from the raw decoded text.

    First tries to capture the span between <|assistant|> and the next role
    tag; otherwise strips role tags and de-duplicates repeated lines.
    """
    assistant_pattern = re.compile(r'<\|assistant\|>\s*(.*?)(?:</\|assistant\|>|<\|user\|>|<\|system\|>)', re.DOTALL)
    matches = assistant_pattern.findall(response_text)
    if matches:
        for match in matches:
            cleaned = match.strip()
            if cleaned and not cleaned.startswith("<|") and len(cleaned) > 5:
                return cleaned
    # Fallback: drop any remaining role tags and collapse duplicate lines.
    cleaned = re.sub(r'</?\|.*?\|>', '', response_text)
    lines = cleaned.strip().split('\n')
    unique_lines = []
    for line in lines:
        line = line.strip()
        if line and line not in unique_lines:
            unique_lines.append(line)
    return '\n'.join(unique_lines).strip()


app = FastAPI(title="DeepSeek QA with KV Cache API")

# In-memory registry of prefilled caches keyed by cache_id; entries are lost
# on restart, so /save_cache and /load_cache persist them to disk.
cache_store = {}
| | model_id ="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" |
| | tokenizer = AutoTokenizer.from_pretrained(model_id,trust_remote_code=True) |
| | model = AutoModelForCausalLM.from_pretrained( |
| | model_id, |
| | torch_dtype=torch.float32, |
| | low_cpu_mem_usage=True,trust_remote_code=True |
| | ) |
| |
|


class QueryRequest(BaseModel):
    query: str
    max_new_tokens: Optional[int] = 150
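

# Example JSON body for /generate_answer_from_cache/{cache_id} (illustrative
# values, not taken from the source):
#   {"query": "What is the document about?", "max_new_tokens": 150}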
| | @app.post("/upload-document_to_create_KV_cache") |
| | async def upload_document(file: UploadFile = File(...)): |
| | t1 = time() |
| | with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file: |
| | temp_file_path = temp_file.name |
| | content = await file.read() |
| | temp_file.write(content) |
| | try: |
| | with open(temp_file_path, "r", encoding="utf-8") as f: |
| | doc_text = f.read() |
| | system_prompt = f""" |
| | <|system|> |
| | Answer concisely and precisely, You are an assistant who provides concise factual answers. |
| | <|user|> |
| | Context: |
| | {doc_text} |
| | Question: |
| | """.strip() |
| | cache, origin_len = get_kv_cache(model, tokenizer, system_prompt) |
| | cache_id = f"cache_{int(time())}" |
| | cache_store[cache_id] = { |
| | "cache": cache, |
| | "origin_len": origin_len, |
| | "doc_preview": doc_text[:500] + "..." if len(doc_text) > 500 else doc_text |
| | } |
| | os.unlink(temp_file_path) |
| | t2 = time() |
| | return { |
| | "cache_id": cache_id, |
| | "message": "Document uploaded and cache created successfully", |
| | "doc_preview": cache_store[cache_id]["doc_preview"], |
| | "time_taken": f"{t2 - t1:.4f} seconds" |
| | } |
| | except Exception as e: |
| | if os.path.exists(temp_file_path): |
| | os.unlink(temp_file_path) |
| | raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}") |
| |
|
| | @app.post("/generate_answer_from_cache/{cache_id}") |
| | async def generate_answer(cache_id: str, request: QueryRequest): |
| | t1 = time() |
| | if cache_id not in cache_store: |
| | raise HTTPException(status_code=404, detail="Document not found. Please upload it first.") |
| | try: |
| | current_cache = clean_up( |
| | cache_store[cache_id]["cache"], |
| | cache_store[cache_id]["origin_len"] |
| | ) |
| | full_prompt = f""" |
| | <|user|> |
| | Question: {request.query} |
| | <|assistant|> |
| | """.strip() |
| | input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids |
| | output_ids = generate(model, input_ids, current_cache, max_new_tokens=request.max_new_tokens) |
| | response = tokenizer.decode(output_ids[0], skip_special_tokens=True) |
| | rep = clean_response(response) |
| | t2 = time() |
| | return { |
| | "query": request.query, |
| | "answer": rep, |
| | "time_taken": f"{t2 - t1:.4f} seconds" |
| | } |
| | except Exception as e: |
| | raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}") |
| |
|
| | @app.post("/save_cache/{cache_id}") |
| | async def save_cache(cache_id: str): |
| | if cache_id not in cache_store: |
| | raise HTTPException(status_code=404, detail="Document not found. Please upload it first.") |
| | try: |
| | cleaned_cache = clean_up( |
| | cache_store[cache_id]["cache"], |
| | cache_store[cache_id]["origin_len"] |
| | ) |
| | cache_path = f"{cache_id}_cache.pth" |
| | torch.save(cleaned_cache, cache_path) |
| | return { |
| | "message": f"Cache saved successfully as {cache_path}", |
| | "cache_path": cache_path |
| | } |
| | except Exception as e: |
| | raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}") |
| |
|
| | @app.post("/load_cache") |
| | async def load_cache(file: UploadFile = File(...)): |
| | with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as temp_file: |
| | temp_file_path = temp_file.name |
| | content = await file.read() |
| | temp_file.write(content) |
| | try: |
| | loaded_cache = torch.load(temp_file_path) |
| | cache_id = f"loaded_cache_{int(time())}" |
| | cache_store[cache_id] = { |
| | "cache": loaded_cache, |
| | "origin_len": loaded_cache.key_cache[0].shape[-2], |
| | "doc_preview": "Loaded from cache file" |
| | } |
| | os.unlink(temp_file_path) |
| | return { |
| | "cache_id": cache_id, |
| | "message": "Cache loaded successfully" |
| | } |
| | except Exception as e: |
| | if os.path.exists(temp_file_path): |
| | os.unlink(temp_file_path) |
| | raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}") |
| |
|
| | @app.get("/list_of_caches") |
| | async def list_documents(): |
| | documents = {} |
| | for cache_id in cache_store: |
| | documents[cache_id] = { |
| | "doc_preview": cache_store[cache_id]["doc_preview"], |
| | "origin_len": cache_store[cache_id]["origin_len"] |
| | } |
| | return {"documents": documents} |
| |
|
| | @app.get("/", include_in_schema=False) |
| | async def root(): |
| | return RedirectResponse(url="/docs") |
| |
|
| | if __name__ == "__main__": |
| | uvicorn.run(app, host="0.0.0.0", port=7860) |
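

# Typical round trip against a running instance (illustrative commands; the
# endpoint paths and port come from this file, while file names and
# <cache_id> are placeholders):
#
#   curl -F "file=@document.txt" \
#        http://localhost:7860/upload-document_to_create_KV_cache
#   curl -X POST http://localhost:7860/generate_answer_from_cache/<cache_id> \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Summarize the document."}'
#   curl -X POST http://localhost:7860/save_cache/<cache_id>
#   curl -F "file=@<cache_id>_cache.pth" http://localhost:7860/load_cache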