# BudgetBot / qwen_app.py
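"""FastAPI service for retrieval-augmented Q&A over uploaded PDF/TXT files.

Documents are split into chunks, embedded with a HuggingFace sentence-transformer,
and indexed in FAISS; questions are answered by a Qwen model served through
OpenRouter, with an SQLite cache for repeated queries and processed files.
"""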
import os
import tempfile
import hashlib
import pickle
import sqlite3
import json
import time
import threading
from typing import List, Any
from pathlib import Path
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain.docstore.document import Document
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from fastapi import FastAPI, UploadFile, Form
from fastapi.responses import JSONResponse
import requests
# -------------------------
# Configuration
# -------------------------
CACHE_DIR = Path(tempfile.gettempdir()) / "budgetbot_cache"
CACHE_DIR.mkdir(exist_ok=True)
DB_PATH = CACHE_DIR / "cache.db"
# File size limits (in bytes)
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
MIN_TEXT_LENGTH = 100 # Minimum text length after extraction
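# NOTE: the chat model reads OPENROUTER_API_KEY (and optionally OPENAI_API_BASE)
# from the environment; set it before starting the app, for example:
#   export OPENROUTER_API_KEY=sk-or-...   # placeholder value, not a real key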
# -------------------------
# SQLite Cache
# -------------------------
class SimpleCache:
    """SQLite-backed cache for query responses and pickled vectorstores."""

    def __init__(self, db_path=DB_PATH):
        self.conn = sqlite3.connect(str(db_path), check_same_thread=False)
        self.lock = threading.Lock()
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS cache (
                key TEXT PRIMARY KEY,
                value TEXT,
                timestamp REAL
            )
        ''')
        self.conn.execute('''
            CREATE TABLE IF NOT EXISTS vectorstore_cache (
                file_hash TEXT PRIMARY KEY,
                data BLOB,
                timestamp REAL
            )
        ''')
        self.conn.commit()

    def get(self, key):
        with self.lock:
            cursor = self.conn.execute(
                'SELECT value FROM cache WHERE key = ?', (key,)
            )
            row = cursor.fetchone()
            return json.loads(row[0]) if row else None

    def set(self, key, value, ttl=3600):
        # NOTE: ttl is currently unused; entries are pruned by cleanup_old().
        with self.lock:
            self.conn.execute(
                'INSERT OR REPLACE INTO cache VALUES (?, ?, ?)',
                (key, json.dumps(value), time.time())
            )
            self.conn.commit()

    def get_vectorstore(self, file_hash):
        with self.lock:
            cursor = self.conn.execute(
                'SELECT data FROM vectorstore_cache WHERE file_hash = ?', (file_hash,)
            )
            row = cursor.fetchone()
            if row:
                try:
                    return pickle.loads(row[0])
                except Exception as e:
                    print(f"Failed to load vectorstore from cache: {e}")
                    return None
            return None

    def set_vectorstore(self, file_hash, vectorstore):
        with self.lock:
            try:
                # Pickling a FAISS vectorstore can fail on some builds; the
                # error is logged and the store simply is not cached.
                data = pickle.dumps(vectorstore)
                self.conn.execute(
                    'INSERT OR REPLACE INTO vectorstore_cache VALUES (?, ?, ?)',
                    (file_hash, data, time.time())
                )
                self.conn.commit()
            except Exception as e:
                print(f"Cache error: {e}")

    def cleanup_old(self, max_age=86400):
        """Remove cache entries older than max_age seconds."""
        with self.lock:
            cutoff = time.time() - max_age
            self.conn.execute('DELETE FROM cache WHERE timestamp < ?', (cutoff,))
            self.conn.execute('DELETE FROM vectorstore_cache WHERE timestamp < ?', (cutoff,))
            self.conn.commit()

# Initialize cache
cache = SimpleCache()
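
# Illustrative usage of SimpleCache (keys and values here are made up):
#   cache.set("qa_abc123", {"answer": "42"})
#   cache.get("qa_abc123")   # -> {"answer": "42"}
#   cache.cleanup_old(3600)  # drop entries older than one hour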
# -------------------------
# Embedding Model (Cached)
# -------------------------
_embeddings_lock = threading.Lock()
_embeddings_cache = None
def get_embeddings():
    """Lazily build the HuggingFace embedding model (double-checked locking)."""
    global _embeddings_cache
    if _embeddings_cache is None:
        with _embeddings_lock:
            if _embeddings_cache is None:
                _embeddings_cache = HuggingFaceEmbeddings(
                    model_name="sentence-transformers/all-MiniLM-L12-v2",
                    model_kwargs={'device': 'cpu'}
                )
    return _embeddings_cache

# -------------------------
# File processing
# -------------------------
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,
    chunk_overlap=100,
    separators=["\n\n", "\n", ". ", " ", ""],
    length_function=len,
)

def get_file_hash(file_path: str) -> str:
    """Generate the SHA-256 hash of a file, read in 8 KB chunks."""
    sha256 = hashlib.sha256()
    with open(file_path, 'rb') as f:
        for chunk in iter(lambda: f.read(8192), b''):
            sha256.update(chunk)
    return sha256.hexdigest()

def process_file_path(file_path: str) -> List[Document]:
    """Load and split a PDF or TXT file into LangChain Documents."""
    try:
        if file_path.lower().endswith(".pdf"):
            loader = PyPDFLoader(file_path)
        else:
            loader = TextLoader(file_path, encoding='utf-8')
        docs = loader.load()
        # Check that the documents contain enough text
        total_text = "".join(doc.page_content for doc in docs).strip()
        if len(total_text) < MIN_TEXT_LENGTH:
            raise ValueError(
                f"Insufficient text extracted from file. "
                f"Got {len(total_text)} characters. "
                f"This PDF may contain only images or scanned content. "
                f"Please provide a text-based PDF or use OCR."
            )
        # Split into chunks and filter out empty ones
        split_docs = text_splitter.split_documents(docs)
        split_docs = [doc for doc in split_docs if doc.page_content.strip()]
        if not split_docs:
            raise ValueError("No valid text chunks after processing")
        print(f"Processed {len(split_docs)} text chunks from file")
        return split_docs
    except Exception as e:
        print(f"Error processing file: {e}")
        raise

# -------------------------
# LLM (Cached)
# -------------------------
_llm_cache = None
_llm_lock = threading.Lock()
def get_llm():
    """Lazily build the ChatOpenAI client pointed at OpenRouter."""
    global _llm_cache
    if _llm_cache is None:
        with _llm_lock:
            if _llm_cache is None:
                _llm_cache = ChatOpenAI(
                    model="qwen/qwen-2.5-7b-instruct",
                    streaming=True,
                    temperature=0,
                    max_tokens=512,
                    openai_api_base=os.environ.get("OPENAI_API_BASE", "https://openrouter.ai/api/v1"),
                    openai_api_key=os.environ.get("OPENROUTER_API_KEY")
                )
    return _llm_cache

# -------------------------
# Retrieval QA Pipeline
# -------------------------
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: Any, vectorstore: FAISS) -> None:
        self.llm = llm
        self.vectorstore = vectorstore
        system_template = (
            "You are a helpful assistant. "
            "Use the following context to answer a user's question. "
            "If the context does not contain the answer, reply with 'I don't know'."
        )
        self.prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template("Context:\n{context}\n\nQuestion:\n{question}")
        ])

    def _get_cache_key(self, user_query: str) -> str:
        """Generate a cache key for a query."""
        return f"qa_{hashlib.md5(user_query.encode()).hexdigest()}"

    async def arun_pipeline(self, user_query: str):
        # Serve a cached answer if we have one
        cache_key = self._get_cache_key(user_query)
        cached_response = cache.get(cache_key)
        if cached_response:
            async def cached_generator():
                yield cached_response['answer']
            return {"response": cached_generator(), "context": cached_response.get('context', []), "cached": True}

        # Retrieve the most relevant chunks and build the prompt
        docs = self.vectorstore.similarity_search(user_query, k=4)
        context_text = "\n".join(doc.page_content for doc in docs)
        messages = self.prompt.format_messages(context=context_text, question=user_query)

        async def generate_response():
            # Stream tokens to the caller, then cache the full answer once
            # streaming has finished. (The original scheduled a background
            # task with a 0.5 s sleep, which raced against the consumer and
            # could cache an incomplete answer.)
            full_response = ""
            async for chunk in self.llm.astream(messages):
                content = chunk.content if chunk.content else ""
                full_response += content
                yield content
            if full_response:
                cache.set(cache_key, {
                    'answer': full_response,
                    'context': [{'page_content': doc.page_content} for doc in docs]
                })

        return {"response": generate_response(), "context": docs, "cached": False}
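
# Illustrative use of the pipeline outside FastAPI (assumes a vectorstore `vs`
# already built via FAISS.from_documents, and an async context):
#   pipeline = RetrievalAugmentedQAPipeline(llm=get_llm(), vectorstore=vs)
#   result = await pipeline.arun_pipeline("What is the total budget?")
#   async for token in result["response"]:
#       print(token, end="")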
# -------------------------
# FastAPI (API Mode)
# -------------------------
app = FastAPI()
global_pipeline = None
current_file_hash = None
@app.post("/upload/")
async def upload_file(file: UploadFile):
global global_pipeline, current_file_hash
try:
if not file or not file.filename:
return JSONResponse({"error": "No file provided"}, status_code=400)
# Check file extension
if not (file.filename.lower().endswith('.pdf') or file.filename.lower().endswith('.txt')):
return JSONResponse(
{"error": "Only PDF and TXT files are supported"},
status_code=400
)
# Read file content
content = await file.read()
# Check file size
if len(content) > MAX_FILE_SIZE:
return JSONResponse(
{"error": f"File too large. Maximum size is {MAX_FILE_SIZE // (1024*1024)}MB"},
status_code=400
)
if len(content) == 0:
return JSONResponse({"error": "Empty file"}, status_code=400)
# Save uploaded file
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file.filename.split('.')[-1]}") as tmp:
tmp.write(content)
tmp_path = tmp.name
# Check if file already processed
file_hash = get_file_hash(tmp_path)
# Try to get from cache
vectorstore = cache.get_vectorstore(file_hash)
cached = vectorstore is not None
if vectorstore is None:
# Process file
try:
texts = process_file_path(tmp_path)
if not texts:
os.unlink(tmp_path)
return JSONResponse(
{"error": "No text content found in file. PDF may contain only images."},
status_code=400
)
embeddings = get_embeddings()
vectorstore = FAISS.from_documents(texts, embeddings)
# Cache vectorstore
cache.set_vectorstore(file_hash, vectorstore)
except ValueError as ve:
os.unlink(tmp_path)
return JSONResponse({"error": str(ve)}, status_code=400)
except Exception as e:
os.unlink(tmp_path)
return JSONResponse(
{"error": f"Failed to process file: {str(e)}"},
status_code=500
)
# Create pipeline
chat_llm = get_llm()
global_pipeline = RetrievalAugmentedQAPipeline(llm=chat_llm, vectorstore=vectorstore)
current_file_hash = file_hash
# Cleanup temp file
try:
os.unlink(tmp_path)
except:
pass
return JSONResponse({
"status": "File uploaded and processed ✅",
"filename": file.filename,
"cached": cached,
"file_hash": file_hash,
"file_size_mb": round(len(content) / (1024*1024), 2)
})
except Exception as e:
return JSONResponse(
{"error": f"Upload failed: {str(e)}"},
status_code=500
)
@app.post("/upload_url/")
async def upload_file_url(file_url: str = Form(...)):
global global_pipeline, current_file_hash
try:
# Download file from URL
response = requests.get(file_url, stream=True, timeout=60)
if response.status_code != 200:
return JSONResponse(
{"error": f"Failed to download file: {response.status_code}"},
status_code=400
)
filename = file_url.split("/")[-1] or "downloaded_file.pdf"
# Check file extension
if not (filename.lower().endswith('.pdf') or filename.lower().endswith('.txt')):
return JSONResponse(
{"error": "Only PDF and TXT files are supported"},
status_code=400
)
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{filename.split('.')[-1]}") as tmp:
total_size = 0
for chunk in response.iter_content(chunk_size=8192):
total_size += len(chunk)
if total_size > MAX_FILE_SIZE:
os.unlink(tmp.name)
return JSONResponse(
{"error": f"File too large. Maximum size is {MAX_FILE_SIZE // (1024*1024)}MB"},
status_code=400
)
tmp.write(chunk)
tmp_path = tmp.name
# Check if file already processed
file_hash = get_file_hash(tmp_path)
# Try to get from cache
vectorstore = cache.get_vectorstore(file_hash)
cached = vectorstore is not None
if vectorstore is None:
# Process file
try:
texts = process_file_path(tmp_path)
if not texts:
os.unlink(tmp_path)
return JSONResponse(
{"error": "No text content found in file. PDF may contain only images."},
status_code=400
)
embeddings = get_embeddings()
vectorstore = FAISS.from_documents(texts, embeddings)
# Cache vectorstore
cache.set_vectorstore(file_hash, vectorstore)
except ValueError as ve:
os.unlink(tmp_path)
return JSONResponse({"error": str(ve)}, status_code=400)
except Exception as e:
os.unlink(tmp_path)
return JSONResponse(
{"error": f"Failed to process file: {str(e)}"},
status_code=500
)
# Create pipeline
chat_llm = get_llm()
global_pipeline = RetrievalAugmentedQAPipeline(llm=chat_llm, vectorstore=vectorstore)
current_file_hash = file_hash
# Cleanup temp file
try:
os.unlink(tmp_path)
except:
pass
return JSONResponse({
"status": "File downloaded and processed ✅",
"filename": filename,
"cached": cached,
"file_hash": file_hash,
"file_size_mb": round(total_size / (1024*1024), 2)
})
except requests.exceptions.RequestException as e:
return JSONResponse(
{"error": f"Download failed: {str(e)}"},
status_code=500
)
except Exception as e:
return JSONResponse(
{"error": f"Processing failed: {str(e)}"},
status_code=500
)
@app.post("/ask/")
async def ask_question(question: str = Form(...)):
global global_pipeline
if not global_pipeline:
return JSONResponse({"error": "No file uploaded yet."}, status_code=400)
try:
result = await global_pipeline.arun_pipeline(question)
response_text = ""
async for token in result["response"]:
response_text += token
return JSONResponse({
"answer": response_text,
"cached": result.get("cached", False)
})
except Exception as e:
return JSONResponse(
{"error": f"Question processing failed: {str(e)}"},
status_code=500
)
@app.get("/health")
async def health_check():
return JSONResponse({
"status": "healthy",
"pipeline_loaded": global_pipeline is not None,
"current_file_hash": current_file_hash
})
@app.post("/clear_cache/")
async def clear_cache():
"""Clear all caches"""
try:
cache.cleanup_old(max_age=0)
return JSONResponse({"status": "Cache cleared ✅"})
except Exception as e:
return JSONResponse(
{"error": f"Cache clear failed: {str(e)}"},
status_code=500
)
# -------------------------
# Run app (for Spaces/Colab/Local)
# -------------------------
if __name__ == "__main__":
    import uvicorn

    # Cleanup old cache on startup (entries older than 24 hours)
    cache.cleanup_old(max_age=86400)
    uvicorn.run("qwen_app:app", host="0.0.0.0", port=7860, reload=False)
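
# Illustrative client calls against a local instance (paths and filenames are examples):
#   curl -X POST -F "file=@budget.pdf" http://localhost:7860/upload/
#   curl -X POST -F "file_url=https://example.com/report.pdf" http://localhost:7860/upload_url/
#   curl -X POST -F "question=What is the total budget?" http://localhost:7860/ask/
#   curl http://localhost:7860/health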