| | import os |
| | import uuid |
| | from flask import Flask, request, render_template, session, jsonify, Response, stream_with_context |
| | from werkzeug.utils import secure_filename |
| | from rag_processor import create_rag_chain |
| | from typing import Sequence, Any, List |
| | import fitz |
| | import re |
| | import io |
| | from gtts import gTTS |
| | from langchain_core.documents import Document |
| | from langchain_community.document_loaders import TextLoader, Docx2txtLoader |
| | from langchain.text_splitter import RecursiveCharacterTextSplitter |
| | from langchain_huggingface import HuggingFaceEmbeddings |
| | from langchain_community.vectorstores import FAISS |
| | from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever |
| | from langchain.retrievers.document_compressors.base import BaseDocumentCompressor |
| | from langchain_community.retrievers import BM25Retriever |
| | from langchain_community.chat_message_histories import ChatMessageHistory |
| | from langchain.storage import InMemoryStore |
| | from sentence_transformers.cross_encoder import CrossEncoder |
| |
|
# Flask application. The secret key is regenerated on every start, so any
# server-side session state does not survive a restart.
app = Flask(__name__)
app.config['SECRET_KEY'] = os.urandom(24)

# Human-readable labels for the temperature presets offered by the UI.
TEMPERATURE_LABELS = dict(
    zip(
        ('0.2', '0.4', '0.6', '0.8', '1.0'),
        ('Precise', 'Confident', 'Balanced', 'Flexible', 'Creative'),
    )
)
class LocalReranker(BaseDocumentCompressor):
    """Cross-encoder based document compressor.

    Re-scores retrieved documents against the query with a cross-encoder
    and keeps only the ``top_n`` highest-scoring ones.
    """

    model: Any   # a sentence-transformers CrossEncoder (or compatible .predict)
    top_n: int = 5  # number of documents kept after reranking

    class Config:
        # The CrossEncoder instance is not a pydantic-known type.
        arbitrary_types_allowed = True

    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks=None) -> Sequence[Document]:
        """Return the ``top_n`` documents ranked by cross-encoder score.

        Each surviving document gets a ``rerank_score`` entry added to its
        metadata (documents are mutated in place).
        """
        if not documents:
            return []
        query_doc_pairs = [[query, d.page_content] for d in documents]
        relevance = self.model.predict(query_doc_pairs, show_progress_bar=False)
        ranked = sorted(zip(documents, relevance),
                        key=lambda pair: pair[1], reverse=True)
        selected = []
        for document, relevance_score in ranked[:self.top_n]:
            document.metadata['rerank_score'] = float(relevance_score)
            selected.append(document)
        return selected
| |
|
| |
|
def create_optimized_parent_child_chunks(all_docs):
    """Split documents into large "parent" chunks and small "child" chunks.

    Children carry a ``doc_id`` pointing back at their parent plus
    positional metadata, so retrieval can match fine-grained child text
    and later fetch the broader parent for context.

    Returns a ``(parent_docs, child_docs, doc_ids)`` tuple; all three are
    empty lists when ``all_docs`` is empty.
    """
    if not all_docs:
        print('CHUNKING: No input documents provided!')
        return ([], [], [])

    # Coarse splitter for parents, finer one for children; both prefer
    # breaking at paragraph/sentence boundaries before plain whitespace.
    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=900,
        chunk_overlap=200,
        separators=['\n\n', '\n', '. ', '! ', '? ', '; ', ', ', ' ', ''],
    )
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=350,
        chunk_overlap=80,
        separators=['\n', '. ', '! ', '? ', '; ', ', ', ' ', ''],
    )

    parent_docs = parent_splitter.split_documents(all_docs)
    doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

    child_docs = []
    for parent_id, parent_doc in zip(doc_ids, parent_docs):
        children = child_splitter.split_documents([parent_doc])
        total = len(children)
        for position, child in enumerate(children):
            child.metadata.update({
                'doc_id': parent_id,
                'chunk_index': position,
                'total_chunks': total,
                'is_first_chunk': position == 0,
                'is_last_chunk': position == total - 1,
            })
            # Mark the edges of multi-chunk parents so downstream prompts
            # can tell whether an excerpt starts or continues a passage.
            if total > 1:
                if position == 0:
                    child.page_content = '[Beginning] ' + child.page_content
                elif position == total - 1:
                    child.page_content = '[Continues...] ' + child.page_content
            child_docs.append(child)

    print(f"CHUNKING: Created {len(parent_docs)} parent and {len(child_docs)} child chunks.")
    return (parent_docs, child_docs, doc_ids)
def get_context_aware_parents(docs: List[Document], store: InMemoryStore) -> List[Document]:
    """Map retrieved child chunks back to their parent documents.

    Parents are fetched from ``store`` via each child's ``doc_id``
    metadata, enriched with up to three matching child excerpts,
    annotated with how many children matched, and returned sorted by
    that match count (descending).
    """
    if not docs:
        return []

    # Tally matching children per parent and collect their text.
    parent_scores = {}
    child_content_by_parent = {}
    for child in docs:
        parent_id = child.metadata.get('doc_id')
        if not parent_id:
            continue  # child was indexed without a parent link; ignore it
        parent_scores[parent_id] = parent_scores.get(parent_id, 0) + 1
        child_content_by_parent.setdefault(parent_id, []).append(child.page_content)

    parent_ids = list(parent_scores.keys())
    enhanced_parents = []
    for parent_id, parent in zip(parent_ids, store.mget(parent_ids)):
        if parent is None:
            print(f"PARENT_FETCH: Parent {parent_id} not found in store!")
            continue
        if parent_id in child_content_by_parent:
            child_excerpts = '\n'.join(child_content_by_parent[parent_id][:3])
            enhanced_content = f"{parent.page_content}\n\nRelevant excerpts:\n{child_excerpts}"
            enhanced_parents.append(Document(
                page_content=enhanced_content,
                metadata={
                    **parent.metadata,
                    'child_relevance_score': parent_scores[parent_id],
                    'matching_children': len(child_content_by_parent[parent_id]),
                },
            ))

    # Most strongly supported parents first.
    enhanced_parents.sort(key=lambda p: p.metadata.get('child_relevance_score', 0),
                          reverse=True)
    return enhanced_parents
| |
|
| |
|
# Hugging Face Spaces containers only allow writes under /tmp, so choose
# the upload directory based on the runtime environment.
is_hf_spaces = bool(os.getenv('SPACE_ID') or os.getenv('SPACES_ZERO_GPU'))
app.config['UPLOAD_FOLDER'] = '/tmp/uploads' if is_hf_spaces else 'uploads'

try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    # The local 'uploads' directory may not be writable; retry under /tmp.
    print(f"Failed to create upload folder, falling back to /tmp: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
| |
|
# Per-session RAG chains and chat histories, keyed by session id.
# Both live only in process memory and are lost on restart.
session_data = {}
message_histories = {}

# Load both models eagerly at import time and fail fast (re-raise) if
# either cannot be loaded — the app is useless without them.
print('Loading embedding model...')
try:
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
    print('Embedding model loaded.')
except Exception as e:
    print(f"FATAL: Could not load embedding model. Error: {e}")
    raise e

print('Loading reranker model...')
try:
    RERANKER_MODEL = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2',
                                  device='cpu')
    print('Reranker model loaded.')
except Exception as e:
    print(f"FATAL: Could not load reranker model. Error: {e}")
    raise e
| |
|
| |
|
def load_pdf_with_fallback(filepath):
    """Extract per-page text from a PDF using PyMuPDF.

    Returns a list of ``Document`` objects, one per non-empty page, with
    ``source`` and 1-based ``page`` metadata. Raises if the file cannot
    be opened or contains no extractable text.
    """
    try:
        source_name = os.path.basename(filepath)
        pages = []
        with fitz.open(filepath) as pdf_doc:
            for page_index, page in enumerate(pdf_doc):
                text = page.get_text()
                if not text.strip():
                    continue  # skip image-only / blank pages
                pages.append(Document(
                    page_content=text,
                    metadata={'source': source_name, 'page': page_index + 1},
                ))
        if not pages:
            raise ValueError('No text content found in PDF.')
        print(f"Loaded PDF: {source_name} - {len(pages)} pages")
        return pages
    except Exception as e:
        print(f"PyMuPDF failed for {filepath}: {e}")
        raise
| |
|
| |
|
# Maps a lowercase file extension to its loader. ``.pdf`` maps to a plain
# function that returns Documents directly; the other entries are loader
# classes whose instances must be ``.load()``-ed by the caller.
LOADER_MAPPING = {
    '.txt': TextLoader,
    '.pdf': load_pdf_with_fallback,
    '.docx': Docx2txtLoader,
}
| |
|
| |
|
def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use."""
    history = message_histories.get(session_id)
    if history is None:
        history = ChatMessageHistory()
        message_histories[session_id] = history
    return history
| |
|
| |
|
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: always reports healthy with HTTP 200."""
    return jsonify({'status': 'healthy'}), 200
| |
|
| |
|
@app.route('/', methods=['GET'])
def index():
    """Serve the single-page chat UI."""
    return render_template('index.html')
| |
|
| |
|
@app.route('/upload', methods=['POST'])
def upload_files():
    """Ingest uploaded documents and build a per-session RAG pipeline.

    Form fields: ``file`` (one or more uploads), ``temperature``,
    ``model_name``. On success returns JSON containing a fresh
    ``session_id`` whose chain has been stored in ``session_data``;
    on failure returns a JSON error with status 400 or 500.
    """
    files = request.files.getlist('file')

    temperature_str = request.form.get('temperature', '0.2')
    try:
        temperature = float(temperature_str)
    except ValueError:
        # A malformed slider value must not 500 the whole upload; fall
        # back to the most conservative preset.
        temperature_str = '0.2'
        temperature = 0.2
    model_name = request.form.get('model_name',
                                  'moonshotai/kimi-k2-instruct')
    print(f"UPLOAD: Model: {model_name}, Temp: {temperature}")

    if not files or all(f.filename == '' for f in files):
        return (jsonify({'status': 'error',
                         'message': 'No selected files.'}), 400)

    (all_docs, processed_files, failed_files) = ([], [], [])
    print(f"Processing {len(files)} file(s)...")
    for file in files:
        if file and file.filename:
            filename = secure_filename(file.filename)
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            try:
                file.save(filepath)
                file_ext = os.path.splitext(filename)[1].lower()
                if file_ext not in LOADER_MAPPING:
                    raise ValueError('Unsupported file format.')
                loader_func = LOADER_MAPPING[file_ext]
                # The .pdf entry is a plain function returning Documents;
                # the other entries are loader classes needing .load().
                docs = loader_func(filepath) if file_ext == '.pdf' \
                    else loader_func(filepath).load()
                if not docs:
                    raise ValueError('No content extracted.')
                all_docs.extend(docs)
                processed_files.append(filename)
            except Exception as e:
                # BUG FIX: report the actual filename instead of the
                # literal placeholder "(unknown)".
                print(f"✗ Error processing {filename}: {e}")
                failed_files.append(f"{filename} ({e})")

    if not all_docs:
        return (jsonify({'status': 'error',
                         'message': f"Failed to process all files. Reasons: {', '.join(failed_files)}"
                         }), 400)

    print(f"UPLOAD: Processed {len(processed_files)} files.")
    try:
        print('Starting RAG pipeline setup...')
        (parent_docs, child_docs, doc_ids) = \
            create_optimized_parent_child_chunks(all_docs)
        if not child_docs:
            raise ValueError('No child documents created during chunking.')

        # Child chunks go into both a dense (FAISS) and a sparse (BM25)
        # index; parents live in an in-memory store keyed by doc_id.
        vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))
        print(f"Indexed {len(child_docs)} document chunks.")

        bm25_retriever = BM25Retriever.from_documents(child_docs)
        bm25_retriever.k = 12
        faiss_retriever = vectorstore.as_retriever(search_kwargs={'k': 12})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.6, 0.4])
        reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)

        def get_parents(docs: List[Document]) -> List[Document]:
            # Closure binding this session's parent store.
            return get_context_aware_parents(docs, store)

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=reranker,
            base_retriever=ensemble_retriever)
        # LCEL pipe: hybrid retrieve -> rerank -> swap children for parents.
        final_retriever = compression_retriever | get_parents

        session_id = str(uuid.uuid4())
        (rag_chain, api_key_manager) = create_rag_chain(
            retriever=final_retriever,
            get_session_history_func=get_session_history,
            model_name=model_name,
            temperature=temperature)

        session_data[session_id] = {'chain': rag_chain,
                                    'model_name': model_name,
                                    'temperature': temperature,
                                    'api_key_manager': api_key_manager}

        success_msg = f"Processed: {', '.join(processed_files)}"
        if failed_files:
            success_msg += f". Failed: {', '.join(failed_files)}"

        mode_label = TEMPERATURE_LABELS.get(temperature_str,
                                            temperature_str)

        print(f"UPLOAD COMPLETE: Session {session_id} is ready.")

        return jsonify({
            'status': 'success',
            'filename': success_msg,
            'session_id': session_id,
            'model_name': model_name,
            'mode': mode_label,
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return (jsonify({'status': 'error',
                         'message': f'RAG setup failed: {e}'}), 500)
| |
|
@app.route('/chat', methods=['POST', 'GET'])
def chat():
    """Stream a RAG answer for a question as server-sent events (SSE).

    GET passes ``question``/``session_id`` as query params (EventSource
    clients); POST passes them as JSON. Validation and setup errors are
    returned either as a JSON body (POST) or as a single SSE ``error``
    event (GET), since EventSource cannot read non-stream bodies.
    """

    def sse_escape(text: str) -> str:
        # Make a string safe to embed in the hand-built JSON of an SSE
        # "data:" line. Backslashes must be escaped before adding new ones.
        return text.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')

    if request.method == 'GET':
        question = request.args.get('question')
        session_id = request.args.get('session_id')
        # BUG FIX: str() guard — question may be None and None[:50] raises.
        print(f"Received GET request for chat: session={session_id}, question={str(question)[:50]}...")
    elif request.method == 'POST':
        # silent=True: a missing/malformed JSON body yields None instead of
        # an unhandled BadRequest, so validation below produces our error.
        data = request.get_json(silent=True) or {}
        question = data.get('question')
        session_id = data.get('session_id') or session.get('session_id')
        print(f"Received POST request for chat: session={session_id}, question={str(question)[:50]}...")
    else:
        return (jsonify({'status': 'error', 'message': 'Method not allowed'}), 405)

    if not question:
        error_msg = "Error: No question provided."
        print(f"CHAT Validation Error: {error_msg}")
        if request.method == 'GET':
            def error_stream():
                yield f'data: {{"error": "{sse_escape(error_msg)}"}}\n\n'
            return Response(stream_with_context(error_stream()), mimetype='text/event-stream', status=400)
        return jsonify({'status': 'error', 'message': error_msg}), 400

    if not session_id or session_id not in session_data:
        error_msg = "Error: Invalid session. Please upload documents first."
        print(f"CHAT Validation Error: Invalid session {session_id}.")
        if request.method == 'GET':
            def error_stream():
                yield f'data: {{"error": "{sse_escape(error_msg)}"}}\n\n'
            return Response(stream_with_context(error_stream()), mimetype='text/event-stream', status=400)
        return jsonify({'status': 'error', 'message': error_msg}), 400

    try:
        session_info = session_data[session_id]
        rag_chain = session_info['chain']
        model_name = session_info['model_name']
        temperature_float = session_info['temperature']
        temperature_str = str(temperature_float)
        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)

        print(f"CHAT: Streaming response for session {session_id} (Model: {model_name}, Temp: {temperature_float})...")

        def generate_chunks():
            full_response = ''
            # These are constant for the whole stream; escape them once.
            model_name_escaped = sse_escape(model_name)
            mode_label_escaped = sse_escape(mode_label)
            try:
                stream_iterator = rag_chain.stream(
                    {'question': question},
                    config={'configurable': {'session_id': session_id}})

                for chunk in stream_iterator:
                    if isinstance(chunk, str):
                        full_response += chunk
                        token_escaped = sse_escape(chunk)
                        yield f'data: {{"token": "{token_escaped}", "model_name": "{model_name_escaped}", "mode": "{mode_label_escaped}"}}\n\n'
                    else:
                        print(f"Received non-string chunk: {type(chunk)}")

                print('CHAT: Streaming finished successfully.')

            except Exception as e:
                print(f"CHAT Error during streaming generation: {e}")
                import traceback
                traceback.print_exc()
                error_text = sse_escape(f"Error during response generation: {str(e)}")
                yield f'data: {{"error": "{error_text}"}}\n\n'
        return Response(stream_with_context(generate_chunks()), mimetype='text/event-stream')

    except Exception as e:
        print(f"CHAT Setup Error: {e}")
        import traceback
        traceback.print_exc()
        error_msg = f"Error setting up chat stream: {str(e)}"
        if request.method == 'GET':
            def error_stream():
                # BUG FIX: the original escaped the letter "n" (not "\n")
                # and never escaped backslashes, corrupting the message.
                yield f'data: {{"error": "{sse_escape(error_msg)}"}}\n\n'
            return Response(stream_with_context(error_stream()), mimetype='text/event-stream', status=500)
        return (jsonify({'status': 'error', 'message': error_msg}), 500)
| |
|
| |
|
def clean_markdown_for_tts(text: str) -> str:
    """Strip Markdown syntax so the remaining plain text reads naturally
    when spoken by a TTS engine. Applied in order: links are dropped,
    inline markers and list/quote prefixes are removed, then whitespace
    is collapsed to single spaces."""
    substitutions = (
        (r'\[.*?\]\(.*?\)', '', 0),              # links: drop entirely
        (r'[`*_#]', '', 0),                      # emphasis / code / heading chars
        (r'^\s*[\-\*\+]\s+', '', re.MULTILINE),  # bullet markers
        (r'^\s*\d+\.\s+', '', re.MULTILINE),     # numbered-list markers
        (r'^\s*>\s?', '', re.MULTILINE),         # blockquote markers
        (r'\n+', ' ', 0),                        # newlines -> spaces
        (r'\s{2,}', ' ', 0),                     # collapse space runs
    )
    for pattern, replacement, flags in substitutions:
        text = re.sub(pattern, replacement, text, flags=flags)
    return text.strip()
| |
|
| |
|
@app.route('/tts', methods=['POST'])
def text_to_speech():
    """Convert posted Markdown text to spoken MP3 audio via gTTS.

    Expects JSON ``{"text": ...}``. Returns ``audio/mpeg`` bytes on
    success, or a JSON error with status 400 (no usable text) / 500
    (synthesis failed).
    """
    # silent=True: a missing/malformed JSON body yields None instead of an
    # unhandled 400 BadRequest, so we can return a consistent JSON error.
    data = request.get_json(silent=True) or {}
    text = data.get('text')
    if not text:
        return (jsonify({'status': 'error',
                         'message': 'No text provided.'}), 400)
    try:
        clean_text = clean_markdown_for_tts(text)
        if not clean_text:
            return (jsonify({'status': 'error', 'message': 'No speakable text found.'}), 400)

        tts = gTTS(clean_text, lang='en')
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        # Hand Flask the raw bytes rather than the file object: iterating
        # a BytesIO splits on newline bytes, which is wasteful for MP3 data.
        return Response(mp3_fp.getvalue(), mimetype='audio/mpeg')
    except Exception as e:
        print(f"TTS Error: {e}")
        return (jsonify({'status': 'error',
                         'message': 'Failed to generate audio.'}), 500)
| |
|
| |
|
if __name__ == '__main__':
    # Default to 7860 (the Hugging Face Spaces convention) if PORT is unset.
    port = int(os.getenv('PORT', 7860))
    print(f"Starting Flask app on port {port}")
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)