import os
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException, Body
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache, StaticCache
from pydantic import BaseModel
from typing import Optional
import uvicorn
import tempfile
from time import time
# Add necessary serialization safety
torch.serialization.add_safe_globals([DynamicCache])
torch.serialization.add_safe_globals([set])
# These lines allow PyTorch to serialize and deserialize these objects without raising errors,
# ensuring compatibility and functionality during cache saving/loading.
# Minimal generate function for token-by-token generation
def generate(model,
             input_ids,
             past_key_values,
             max_new_tokens=50):
    """
    Perform token-by-token text generation with a pre-trained language model.
    Purpose: generate new text based on the input tokens without re-processing the full context each step.
    Process: takes a model, input IDs, and cached key-values, then generates new tokens one by one up to the specified maximum.
    Performance: reuses the cached key-values for efficiency and returns only the newly generated tokens.
    """
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]  # length of the input sequence before generation begins
    input_ids = input_ids.to(device)  # move inputs to the same device as the model
    output_ids = input_ids.clone()    # updated during generation to include newly generated tokens
    next_token = input_ids            # token(s) that will be processed in the next iteration
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]  # extract the logits for the last position
            token = torch.argmax(logits, dim=-1, keepdim=True)  # greedy decoding: pick the highest-probability token
            output_ids = torch.cat([output_ids, token], dim=-1)  # append the newly generated token
            past_key_values = out.past_key_values
            next_token = token.to(device)
            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break
    return output_ids[:, origin_len:]  # return just the newly generated part
def get_kv_cache(model, tokenizer, prompt):
    """
    Create a key-value cache for a given prompt.
    Purpose: pre-compute and store the model's internal representations (key-value states) for a prompt.
    Process: encodes the prompt, runs it through the model, and captures the resulting cache.
    Returns: the cache object and the original prompt length for future reference.
    """
    # Encode prompt
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()  # it grows as text is generated
    # Run the model to populate the KV cache
    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache, input_ids.shape[-1]
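# Usage sketch (illustrative only; assumes a loaded `model`/`tokenizer` and a document string `doc_text`):
#   cache, origin_len = get_kv_cache(model, tokenizer, f"Context:\n{doc_text}\nQuestion:")
#   query_ids = tokenizer("What is the document about?", return_tensors="pt").input_ids
#   answer_ids = generate(model, query_ids, cache, max_new_tokens=50)
#   print(tokenizer.decode(answer_ids[0], skip_special_tokens=True))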
def clean_up(cache, origin_len):
    """
    Return a copy of the cache trimmed back to the original prompt length,
    discarding any tokens appended during previous generations.
    """
    # Make a deep copy of the cache first
    new_cache = DynamicCache()
    for i in range(len(cache.key_cache)):
        new_cache.key_cache.append(cache.key_cache[i].clone())
        new_cache.value_cache.append(cache.value_cache[i].clone())
    # Remove any tokens appended to the original knowledge
    for i in range(len(new_cache.key_cache)):
        new_cache.key_cache[i] = new_cache.key_cache[i][:, :, :origin_len, :]
        new_cache.value_cache[i] = new_cache.value_cache[i][:, :, :origin_len, :]
    return new_cache
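# Note: each entry in key_cache/value_cache is typically shaped
# (batch, num_heads, seq_len, head_dim), so slicing the seq_len dimension back to
# origin_len keeps the pre-computed document context and drops anything generated later.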
| os.environ["TRANSFORMERS_OFFLINE"] = "1" | |
| os.environ["HF_HUB_OFFLINE"] = "1" | |
| # Path to your local model | |
| # Initialize model and tokenizer | |
| def load_model_and_tokenizer(): | |
| model_path = "./deepseek" | |
| # Load tokenizer and model from disk (without trust_remote_code) | |
| tokenizer = AutoTokenizer.from_pretrained(model_path) | |
| if torch.cuda.is_available(): | |
| # Load model on GPU if CUDA is available | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_path, | |
| torch_dtype=torch.float16, | |
| device_map="auto" # Automatically map model layers to GPU | |
| ) | |
| else: | |
| # Load model on CPU if no GPU is available | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_path, | |
| torch_dtype=torch.float32, # Use float32 for compatibility with CPU | |
| low_cpu_mem_usage=True # Reduce memory usage on CPU | |
| ) | |
| return model, tokenizer | |
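# Note: device_map="auto" relies on the `accelerate` package being installed; without it,
# from_pretrained will raise an error when trying to map layers across devices.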
# Create FastAPI app
app = FastAPI(title="DeepSeek QA with KV Cache API")
# Global variables to store the cache, origin length, and model/tokenizer
cache_store = {}
# Initialize model and tokenizer at startup
model, tokenizer = load_model_and_tokenizer()
class QueryRequest(BaseModel):
    query: str
    max_new_tokens: Optional[int] = 150
def clean_response(response_text):
    """
    Clean up a model response by removing redundant tags, repetitions, and formatting issues.
    """
    import re
    # First, try to extract just the answer content between assistant tags if they exist
    assistant_pattern = re.compile(r'<\|assistant\|>\s*(.*?)(?:<\/\|assistant\|>|<\|user\|>|<\|system\|>)', re.DOTALL)
    matches = assistant_pattern.findall(response_text)
    if matches:
        # Return the first meaningful assistant response
        for match in matches:
            cleaned = match.strip()
            if cleaned and not cleaned.startswith("<|") and len(cleaned) > 5:
                return cleaned
    # If no proper match was found, try more aggressive cleaning:
    # remove all tag markers completely
    cleaned = re.sub(r'<\|.*?\|>', '', response_text)
    cleaned = re.sub(r'<\/\|.*?\|>', '', cleaned)
    # Remove duplicate lines (common in generated responses)
    lines = cleaned.strip().split('\n')
    unique_lines = []
    for line in lines:
        line = line.strip()
        if line and line not in unique_lines:
            unique_lines.append(line)
    result = '\n'.join(unique_lines)
    # Final cleanup - remove any trailing system/user markers
    result = re.sub(r'<\/?\|.*?\|>\s*$', '', result)
    return result.strip()
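# Example (illustrative): given a raw generation such as
#   "<|assistant|> Paris is the capital of France. <|user|> ..."
# clean_response returns "Paris is the capital of France."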
# Route paths below are assumed; adjust them to match your deployment.
@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Upload a document and create a KV cache for it."""
    t1 = time()
    # Save the uploaded file temporarily
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
        temp_file_path = temp_file.name
        content = await file.read()
        temp_file.write(content)
    try:
        # Read the document
        with open(temp_file_path, "r", encoding="utf-8") as f:
            doc_text = f.read()
        # Create system prompt with document context
        system_prompt = f"""
<|system|>
You are an assistant who provides concise, factual answers. Answer concisely and precisely.
<|user|>
Context:
{doc_text}
Question:
""".strip()
        # Create KV cache
        cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
        # Generate a unique ID for this document/cache
        cache_id = f"cache_{int(time())}"
        # Store the cache and origin_len
        cache_store[cache_id] = {
            "cache": cache,
            "origin_len": origin_len,
            "doc_preview": doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
        }
        # Clean up the temporary file
        os.unlink(temp_file_path)
        t2 = time()
        return {
            "cache_id": cache_id,
            "message": "Document uploaded and cache created successfully",
            "doc_preview": cache_store[cache_id]["doc_preview"],
            "time_taken": f"{t2 - t1:.4f} seconds"
        }
    except Exception as e:
        # Clean up the temporary file in case of error
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
@app.post("/query/{cache_id}")  # path assumed
async def generate_answer(cache_id: str, request: QueryRequest):
    """Generate an answer to a question based on the uploaded document."""
    t1 = time()
    # Check if the document/cache exists
    if cache_id not in cache_store:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
    try:
        # Get a clean copy of the cache
        current_cache = clean_up(
            cache_store[cache_id]["cache"],
            cache_store[cache_id]["origin_len"]
        )
        # Prepare input with just the query
        full_prompt = f"""
<|user|>
Question: {request.query}
<|assistant|>
""".strip()
        input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
        # Generate response
        output_ids = generate(model, input_ids, current_cache, max_new_tokens=request.max_new_tokens)
        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        rep = clean_response(response)
        t2 = time()
        return {
            "query": request.query,
            "answer": rep,
            "time_taken": f"{t2 - t1:.4f} seconds"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")
@app.post("/save_cache/{cache_id}")  # path assumed
async def save_cache(cache_id: str):
    """Save the cache for a document to disk."""
    if cache_id not in cache_store:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
    try:
        # Clean up the cache and save it
        cleaned_cache = clean_up(
            cache_store[cache_id]["cache"],
            cache_store[cache_id]["origin_len"]
        )
        cache_path = f"{cache_id}_cache.pth"
        torch.save(cleaned_cache, cache_path)
        return {
            "message": f"Cache saved successfully as {cache_path}",
            "cache_path": cache_path
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}")
@app.post("/load_cache")  # path assumed
async def load_cache(file: UploadFile = File(...)):
    """Load a previously saved cache."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as temp_file:
        temp_file_path = temp_file.name
        content = await file.read()
        temp_file.write(content)
    try:
        # Load the cache
        loaded_cache = torch.load(temp_file_path)
        # Generate a unique ID for this cache
        cache_id = f"loaded_cache_{int(time())}"
        # Store the cache (we don't have the original document text)
        cache_store[cache_id] = {
            "cache": loaded_cache,
            "origin_len": loaded_cache.key_cache[0].shape[-2],  # sequence length of the saved prompt
            "doc_preview": "Loaded from cache file"
        }
        # Clean up the temporary file
        os.unlink(temp_file_path)
        return {
            "cache_id": cache_id,
            "message": "Cache loaded successfully"
        }
    except Exception as e:
        # Clean up the temporary file in case of error
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
        raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}")
@app.get("/documents")  # path assumed
async def list_documents():
    """List all uploaded documents/caches."""
    documents = {}
    for cache_id in cache_store:
        documents[cache_id] = {
            "doc_preview": cache_store[cache_id]["doc_preview"],
            "origin_len": cache_store[cache_id]["origin_len"]
        }
    return {"documents": documents}
@app.get("/")  # path assumed
async def root():
    return {"message": "DeepSeek QA with KV Cache API is running"}
if __name__ == "__main__":
    # Run the FastAPI app
    uvicorn.run(app, host="0.0.0.0", port=7860)