import os
import tempfile
import hashlib
import pickle
import sqlite3
import json
import time
import threading
from typing import List, Any
from pathlib import Path
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain.docstore.document import Document
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from fastapi import FastAPI, UploadFile, Form
from fastapi.responses import JSONResponse
import requests
# -------------------------
# Configuration
# -------------------------
CACHE_DIR = Path(tempfile.gettempdir()) / "budgetbot_cache"
CACHE_DIR.mkdir(exist_ok=True)
DB_PATH = CACHE_DIR / "cache.db"

# File size limits (in bytes)
MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB
MIN_TEXT_LENGTH = 100  # Minimum text length after extraction

# -------------------------
# SQLite Cache
# -------------------------
class SimpleCache:
    def __init__(self, db_path=DB_PATH):
        self.conn = sqlite3.connect(str(db_path), check_same_thread=False)
        self.lock = threading.Lock()
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS cache (
                key TEXT PRIMARY KEY,
                value TEXT,
                timestamp REAL
            )
        ''')
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS vectorstore_cache (
                file_hash TEXT PRIMARY KEY,
                data BLOB,
                timestamp REAL
            )
        ''')
        self.conn.commit()
    def get(self, key):
        with self.lock:
            cursor = self.conn.execute(
                'SELECT value FROM cache WHERE key = ?', (key,)
            )
            row = cursor.fetchone()
            return json.loads(row[0]) if row else None

    def set(self, key, value, ttl=3600):
        # Note: ttl is accepted but not enforced per entry; expiry is handled
        # globally by cleanup_old().
        with self.lock:
            self.conn.execute(
                'INSERT OR REPLACE INTO cache VALUES (?, ?, ?)',
                (key, json.dumps(value), time.time())
            )
            self.conn.commit()
    def get_vectorstore(self, file_hash):
        with self.lock:
            cursor = self.conn.execute(
                'SELECT data FROM vectorstore_cache WHERE file_hash = ?', (file_hash,)
            )
            row = cursor.fetchone()
            if row:
                try:
                    return pickle.loads(row[0])
                except Exception as e:
                    print(f"Failed to load vectorstore from cache: {e}")
                    return None
            return None

    def set_vectorstore(self, file_hash, vectorstore):
        # Pickling a FAISS-backed vectorstore can fail on some faiss builds;
        # errors are swallowed so caching stays best-effort.
        with self.lock:
            try:
                data = pickle.dumps(vectorstore)
                self.conn.execute(
                    'INSERT OR REPLACE INTO vectorstore_cache VALUES (?, ?, ?)',
                    (file_hash, data, time.time())
                )
                self.conn.commit()
            except Exception as e:
                print(f"Cache error: {e}")

    def cleanup_old(self, max_age=86400):
        """Remove cache entries older than max_age seconds"""
        with self.lock:
            cutoff = time.time() - max_age
            self.conn.execute('DELETE FROM cache WHERE timestamp < ?', (cutoff,))
            self.conn.execute('DELETE FROM vectorstore_cache WHERE timestamp < ?', (cutoff,))
            self.conn.commit()
# Initialize cache
cache = SimpleCache()
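
# Illustrative usage sketch (not called by the app): values must be
# JSON-serializable, and keys are arbitrary strings.
def _cache_roundtrip_demo():
    cache.set("demo_key", {"answer": "42"})
    assert cache.get("demo_key") == {"answer": "42"}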
# -------------------------
# Embedding Model (Cached)
# -------------------------
_embeddings_lock = threading.Lock()
_embeddings_cache = None

def get_embeddings():
    global _embeddings_cache
    if _embeddings_cache is None:
        with _embeddings_lock:
            if _embeddings_cache is None:
                _embeddings_cache = HuggingFaceEmbeddings(
                    model_name="sentence-transformers/all-MiniLM-L12-v2",
                    model_kwargs={'device': 'cpu'}
                )
    return _embeddings_cache
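
# Illustrative alternative to pickling (not called by the app): LangChain's
# FAISS wrapper exposes serialize_to_bytes()/deserialize_from_bytes(), which
# sidesteps pickle failures on some faiss builds. The
# allow_dangerous_deserialization flag is required in recent
# langchain_community versions.
def _faiss_bytes_roundtrip_demo(vectorstore: FAISS) -> FAISS:
    data = vectorstore.serialize_to_bytes()
    return FAISS.deserialize_from_bytes(
        data, get_embeddings(), allow_dangerous_deserialization=True
    )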
# -------------------------
# File processing
# -------------------------
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,
    chunk_overlap=100,
    separators=["\n\n", "\n", ". ", " ", ""],
    length_function=len,
)
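
# Illustrative splitter behavior (not called by the app): with these settings
# no chunk exceeds chunk_size, and consecutive chunks share up to
# chunk_overlap characters of context. The sample text is a placeholder.
def _splitter_demo():
    sample = "Budget line item one. " * 100  # ~2200 characters
    chunks = text_splitter.split_text(sample)
    assert len(chunks) > 1
    assert all(len(c) <= 800 for c in chunks)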
def get_file_hash(file_path: str) -> str:
    """Generate SHA256 hash of file"""
    sha256 = hashlib.sha256()
    with open(file_path, 'rb') as f:
        for chunk in iter(lambda: f.read(8192), b''):
            sha256.update(chunk)
    return sha256.hexdigest()
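
# Illustrative property of the hash (not called by the app): identical bytes
# hash identically regardless of filename, which is what makes the digest a
# safe cache key. Hashing is chunked, so memory use stays bounded.
def _file_hash_demo():
    paths = []
    for _ in range(2):
        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(b"same bytes")
            paths.append(f.name)
    try:
        assert get_file_hash(paths[0]) == get_file_hash(paths[1])
    finally:
        for p in paths:
            os.unlink(p)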
def process_file_path(file_path: str) -> List[Document]:
    """Load and split PDF or TXT into LangChain Documents."""
    try:
        if file_path.lower().endswith(".pdf"):
            loader = PyPDFLoader(file_path)
        else:
            loader = TextLoader(file_path, encoding='utf-8')
        docs = loader.load()

        # Check if documents have any text content
        total_text = "".join([doc.page_content for doc in docs]).strip()
        if not total_text or len(total_text) < MIN_TEXT_LENGTH:
            raise ValueError(
                f"Insufficient text extracted from file. "
                f"Got {len(total_text)} characters. "
                f"This PDF may contain only images or scanned content. "
                f"Please provide a text-based PDF or use OCR."
            )

        # Split documents
        split_docs = text_splitter.split_documents(docs)

        # Filter out empty chunks
        split_docs = [doc for doc in split_docs if doc.page_content.strip()]
        if not split_docs:
            raise ValueError("No valid text chunks after processing")

        print(f"Processed {len(split_docs)} text chunks from file")
        return split_docs
    except Exception as e:
        print(f"Error processing file: {e}")
        raise
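
# Illustrative end-to-end call (not called by the app): write a throwaway
# .txt file comfortably above MIN_TEXT_LENGTH, then load and split it.
def _process_file_demo() -> List[Document]:
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, encoding="utf-8"
    ) as tmp:
        tmp.write("Sample budget narrative. " * 50)
        path = tmp.name
    try:
        return process_file_path(path)
    finally:
        os.unlink(path)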
# -------------------------
# LLM (Cached)
# -------------------------
_llm_cache = None
_llm_lock = threading.Lock()

def get_llm():
    global _llm_cache
    if _llm_cache is None:
        with _llm_lock:
            if _llm_cache is None:
                _llm_cache = ChatOpenAI(
                    model="qwen/qwen-2.5-7b-instruct",
                    streaming=True,
                    temperature=0,
                    max_tokens=512,
                    openai_api_base=os.environ.get("OPENAI_API_BASE", "https://openrouter.ai/api/v1"),
                    openai_api_key=os.environ.get("OPENROUTER_API_KEY")
                )
    return _llm_cache
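
# Smoke-test sketch (commented out; requires a valid OpenRouter key, and the
# key value shown is a placeholder):
#   os.environ["OPENROUTER_API_KEY"] = "sk-or-..."
#   print(get_llm().invoke("Reply with the word 'ready'.").content)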
# -------------------------
# Retrieval QA Pipeline
# -------------------------
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: Any, vectorstore: FAISS) -> None:
        self.llm = llm
        self.vectorstore = vectorstore
        system_template = (
            "You are a helpful assistant. "
            "Use the following context to answer a user's question. "
            "If the context does not contain the answer, reply with 'I don't know'."
        )
        self.prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template("Context:\n{context}\n\nQuestion:\n{question}")
        ])

    def _get_cache_key(self, user_query: str) -> str:
        """Generate cache key for query"""
        return f"qa_{hashlib.md5(user_query.encode()).hexdigest()}"
    async def arun_pipeline(self, user_query: str):
        # Check cache first
        cache_key = self._get_cache_key(user_query)
        cached_response = cache.get(cache_key)
        if cached_response:
            async def cached_generator():
                yield cached_response['answer']
            return {"response": cached_generator(), "context": cached_response.get('context', []), "cached": True}

        # Retrieve documents
        docs = self.vectorstore.similarity_search(user_query, k=4)
        context_text = "\n".join([doc.page_content for doc in docs])
        messages = self.prompt.format_messages(context=context_text, question=user_query)

        # Stream the response, then cache the full answer once streaming
        # completes. Caching inside the generator (rather than on a timer)
        # avoids the race of persisting a partial answer.
        async def generate_response():
            full_response = ""
            async for chunk in self.llm.astream(messages):
                content = chunk.content if chunk.content else ""
                full_response += content
                yield content
            if full_response:
                cache.set(cache_key, {
                    'answer': full_response,
                    'context': [{'page_content': doc.page_content} for doc in docs]
                })

        return {"response": generate_response(), "context": docs, "cached": False}
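
# Illustrative consumption of the pipeline (not called by the app): drain the
# streamed tokens into one string, e.g. via asyncio.run(_pipeline_demo(p)).
# The question text is a placeholder.
async def _pipeline_demo(pipeline: RetrievalAugmentedQAPipeline) -> str:
    result = await pipeline.arun_pipeline("What is the total budget?")
    answer = ""
    async for token in result["response"]:
        answer += token
    return answer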
# -------------------------
# FastAPI (API Mode)
# -------------------------
app = FastAPI()

global_pipeline = None
current_file_hash = None
# Route paths on the decorated handlers below are assumed; adjust to match
# your deployment.
@app.post("/upload_file")
async def upload_file(file: UploadFile):
    global global_pipeline, current_file_hash
    try:
        if not file or not file.filename:
            return JSONResponse({"error": "No file provided"}, status_code=400)

        # Check file extension
        if not (file.filename.lower().endswith('.pdf') or file.filename.lower().endswith('.txt')):
            return JSONResponse(
                {"error": "Only PDF and TXT files are supported"},
                status_code=400
            )

        # Read file content
        content = await file.read()

        # Check file size
        if len(content) > MAX_FILE_SIZE:
            return JSONResponse(
                {"error": f"File too large. Maximum size is {MAX_FILE_SIZE // (1024*1024)}MB"},
                status_code=400
            )
        if len(content) == 0:
            return JSONResponse({"error": "Empty file"}, status_code=400)

        # Save uploaded file
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file.filename.split('.')[-1]}") as tmp:
            tmp.write(content)
            tmp_path = tmp.name

        # Check if file already processed
        file_hash = get_file_hash(tmp_path)

        # Try to get from cache
        vectorstore = cache.get_vectorstore(file_hash)
        cached = vectorstore is not None

        if vectorstore is None:
            # Process file
            try:
                texts = process_file_path(tmp_path)
                if not texts:
                    os.unlink(tmp_path)
                    return JSONResponse(
                        {"error": "No text content found in file. PDF may contain only images."},
                        status_code=400
                    )
                embeddings = get_embeddings()
                vectorstore = FAISS.from_documents(texts, embeddings)
                # Cache vectorstore
                cache.set_vectorstore(file_hash, vectorstore)
            except ValueError as ve:
                os.unlink(tmp_path)
                return JSONResponse({"error": str(ve)}, status_code=400)
            except Exception as e:
                os.unlink(tmp_path)
                return JSONResponse(
                    {"error": f"Failed to process file: {str(e)}"},
                    status_code=500
                )

        # Create pipeline
        chat_llm = get_llm()
        global_pipeline = RetrievalAugmentedQAPipeline(llm=chat_llm, vectorstore=vectorstore)
        current_file_hash = file_hash

        # Cleanup temp file
        try:
            os.unlink(tmp_path)
        except OSError:
            pass

        return JSONResponse({
            "status": "File uploaded and processed ✅",
            "filename": file.filename,
            "cached": cached,
            "file_hash": file_hash,
            "file_size_mb": round(len(content) / (1024*1024), 2)
        })
    except Exception as e:
        return JSONResponse(
            {"error": f"Upload failed: {str(e)}"},
            status_code=500
        )
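
# Example request against a local run (route path as registered above; the
# filename is a placeholder):
#   curl -F "file=@budget.pdf" http://localhost:7860/upload_file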
@app.post("/upload_file_url")
async def upload_file_url(file_url: str = Form(...)):
    global global_pipeline, current_file_hash
    try:
        # Download file from URL
        response = requests.get(file_url, stream=True, timeout=60)
        if response.status_code != 200:
            return JSONResponse(
                {"error": f"Failed to download file: {response.status_code}"},
                status_code=400
            )
        filename = file_url.split("/")[-1] or "downloaded_file.pdf"

        # Check file extension
        if not (filename.lower().endswith('.pdf') or filename.lower().endswith('.txt')):
            return JSONResponse(
                {"error": "Only PDF and TXT files are supported"},
                status_code=400
            )

        # Stream to a temp file, enforcing the size limit as we go
        too_large = False
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{filename.split('.')[-1]}") as tmp:
            total_size = 0
            for chunk in response.iter_content(chunk_size=8192):
                total_size += len(chunk)
                if total_size > MAX_FILE_SIZE:
                    too_large = True
                    break
                tmp.write(chunk)
            tmp_path = tmp.name

        if too_large:
            # Unlink only after the handle is closed; unlinking an open file
            # fails on some platforms.
            os.unlink(tmp_path)
            return JSONResponse(
                {"error": f"File too large. Maximum size is {MAX_FILE_SIZE // (1024*1024)}MB"},
                status_code=400
            )

        # Check if file already processed
        file_hash = get_file_hash(tmp_path)

        # Try to get from cache
        vectorstore = cache.get_vectorstore(file_hash)
        cached = vectorstore is not None

        if vectorstore is None:
            # Process file
            try:
                texts = process_file_path(tmp_path)
                if not texts:
                    os.unlink(tmp_path)
                    return JSONResponse(
                        {"error": "No text content found in file. PDF may contain only images."},
                        status_code=400
                    )
                embeddings = get_embeddings()
                vectorstore = FAISS.from_documents(texts, embeddings)
                # Cache vectorstore
                cache.set_vectorstore(file_hash, vectorstore)
            except ValueError as ve:
                os.unlink(tmp_path)
                return JSONResponse({"error": str(ve)}, status_code=400)
            except Exception as e:
                os.unlink(tmp_path)
                return JSONResponse(
                    {"error": f"Failed to process file: {str(e)}"},
                    status_code=500
                )

        # Create pipeline
        chat_llm = get_llm()
        global_pipeline = RetrievalAugmentedQAPipeline(llm=chat_llm, vectorstore=vectorstore)
        current_file_hash = file_hash

        # Cleanup temp file
        try:
            os.unlink(tmp_path)
        except OSError:
            pass

        return JSONResponse({
            "status": "File downloaded and processed ✅",
            "filename": filename,
            "cached": cached,
            "file_hash": file_hash,
            "file_size_mb": round(total_size / (1024*1024), 2)
        })
    except requests.exceptions.RequestException as e:
        return JSONResponse(
            {"error": f"Download failed: {str(e)}"},
            status_code=500
        )
    except Exception as e:
        return JSONResponse(
            {"error": f"Processing failed: {str(e)}"},
            status_code=500
        )
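
# Example request (the URL is a placeholder):
#   curl -F "file_url=https://example.com/budget.pdf" http://localhost:7860/upload_file_url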
@app.post("/ask")
async def ask_question(question: str = Form(...)):
    global global_pipeline
    if not global_pipeline:
        return JSONResponse({"error": "No file uploaded yet."}, status_code=400)
    try:
        result = await global_pipeline.arun_pipeline(question)
        response_text = ""
        async for token in result["response"]:
            response_text += token
        return JSONResponse({
            "answer": response_text,
            "cached": result.get("cached", False)
        })
    except Exception as e:
        return JSONResponse(
            {"error": f"Question processing failed: {str(e)}"},
            status_code=500
        )
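
# Example request (assumes a file was uploaded first):
#   curl -F "question=What is the total budget?" http://localhost:7860/ask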
@app.get("/health")
async def health_check():
    return JSONResponse({
        "status": "healthy",
        "pipeline_loaded": global_pipeline is not None,
        "current_file_hash": current_file_hash
    })

@app.post("/clear_cache")
async def clear_cache():
    """Clear all caches"""
    try:
        cache.cleanup_old(max_age=0)
        return JSONResponse({"status": "Cache cleared ✅"})
    except Exception as e:
        return JSONResponse(
            {"error": f"Cache clear failed: {str(e)}"},
            status_code=500
        )
# -------------------------
# Run app (for Spaces/Colab/Local)
# -------------------------
if __name__ == "__main__":
    import uvicorn

    # Cleanup old cache on startup
    cache.cleanup_old(max_age=86400)  # 24 hours
    uvicorn.run("qwen_app:app", host="0.0.0.0", port=7860, reload=False)