Spaces:
Sleeping
Sleeping
| import os | |
| import uuid | |
| from flask import Flask, request, render_template, session, jsonify, Response, stream_with_context | |
| from werkzeug.utils import secure_filename | |
| from rag_processor import create_rag_chain | |
| from typing import Sequence, Any, List | |
| import fitz | |
| import re | |
| import io | |
| from gtts import gTTS | |
| from langchain_core.documents import Document | |
| from langchain_community.document_loaders import TextLoader, Docx2txtLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever | |
| from langchain.retrievers.document_compressors.base import BaseDocumentCompressor | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain_community.chat_message_histories import ChatMessageHistory | |
| from langchain.storage import InMemoryStore | |
| from sentence_transformers.cross_encoder import CrossEncoder | |
# Flask application object shared by every view in this module.
app = Flask(__name__)
# A fresh random signing key is generated each time the module is imported.
app.config.update(SECRET_KEY=os.urandom(24))

# Maps the temperature string sent by the client to the human-readable
# "mode" label that is echoed back in upload and chat responses.
TEMPERATURE_LABELS = {
    '0.2': 'Precise',
    '0.4': 'Confident',
    '0.6': 'Balanced',
    '0.8': 'Flexible',
    '1.0': 'Creative',
}
class LocalReranker(BaseDocumentCompressor):
    """Document compressor that reorders documents using a local scoring model.

    Each document is paired with the query, scored by ``model.predict`` and
    the ``top_n`` highest-scoring documents are returned, best first.
    """

    # Scoring model; only ``predict(pairs, show_progress_bar=...)`` is called.
    model: Any
    # Maximum number of documents kept after reranking.
    top_n: int = 5

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks=None) -> Sequence[Document]:
        """Return the ``top_n`` documents most relevant to ``query``.

        Args:
            documents: Candidate documents to score.
            query: The query each document is scored against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            The kept documents in descending score order. Each kept document
            is mutated in place: its score is stored as a float under
            ``metadata['rerank_score']``.
        """
        if not documents:
            return []

        scores = self.model.predict(
            [[query, doc.page_content] for doc in documents],
            show_progress_bar=False,
        )
        ranked = sorted(zip(documents, scores),
                        key=lambda pair: pair[1], reverse=True)
        kept = ranked[:self.top_n]
        for doc, score in kept:
            doc.metadata['rerank_score'] = float(score)
        return [doc for doc, _ in kept]
def create_optimized_parent_child_chunks(all_docs):
    """Split documents into parent chunks and smaller child chunks.

    Parents are produced with a 900-character splitter (200 overlap) and each
    parent is split again with a 350-character splitter (80 overlap). Every
    child is tagged with its parent's generated id and its position among
    its siblings.

    Args:
        all_docs: Documents to split.

    Returns:
        A ``(parent_docs, child_docs, doc_ids)`` tuple where ``doc_ids[i]`` is
        the UUID string assigned to ``parent_docs[i]``. Three empty lists are
        returned when ``all_docs`` is empty.
    """
    if not all_docs:
        print('CHUNKING: No input documents provided!')
        return ([], [], [])

    # Both splitters share the same sentence/word level fallbacks; only the
    # leading paragraph/line separators differ.
    sentence_separators = ['. ', '! ', '? ', '; ', ', ', ' ', '']
    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=900,
        chunk_overlap=200,
        separators=['\n\n', '\n'] + sentence_separators,
    )
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=350,
        chunk_overlap=80,
        separators=['\n'] + sentence_separators,
    )

    parent_docs = parent_splitter.split_documents(all_docs)
    doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

    child_docs = []
    for parent_id, parent_doc in zip(doc_ids, parent_docs):
        children = child_splitter.split_documents([parent_doc])
        total = len(children)
        last = total - 1
        for position, child in enumerate(children):
            child.metadata.update({
                'doc_id': parent_id,
                'chunk_index': position,
                'total_chunks': total,
                'is_first_chunk': position == 0,
                'is_last_chunk': position == last,
            })
            # Only multi-chunk parents get a positional marker, and only on
            # their first and last child; middle children are left as-is.
            if total > 1:
                if position == 0:
                    child.page_content = '[Beginning] ' + child.page_content
                elif position == last:
                    child.page_content = '[Continues...] ' + child.page_content
        child_docs.extend(children)

    print(f"CHUNKING: Created {len(parent_docs)} parent and {len(child_docs)} child chunks.")
    return (parent_docs, child_docs, doc_ids)
def get_context_aware_parents(docs: List[Document], store: InMemoryStore) -> List[Document]:
    """Replace retrieved child chunks with their enriched parent documents.

    Children are grouped by ``metadata['doc_id']``; each parent found in
    ``store`` is returned as a new ``Document`` whose content is the parent
    text followed by up to three of the matching child excerpts.

    Args:
        docs: Retrieved child documents.
        store: Store queried with ``mget`` for the parent documents.

    Returns:
        New parent documents sorted by how many children matched (most
        first). Their metadata is the parent's metadata plus
        ``child_relevance_score`` and ``matching_children`` (both equal to
        the number of matching children). Parents missing from the store are
        logged and skipped.
    """
    if not docs:
        return []

    # Group child texts by parent id; children without a doc_id are ignored.
    excerpts_by_parent = {}
    for child in docs:
        parent_id = child.metadata.get('doc_id')
        if not parent_id:
            continue
        excerpts_by_parent.setdefault(parent_id, []).append(child.page_content)

    parent_ids = list(excerpts_by_parent)
    parents = store.mget(parent_ids)

    enhanced_parents = []
    for parent_id, parent in zip(parent_ids, parents):
        if parent is None:
            print(f"PARENT_FETCH: Parent {parent_id} not found in store!")
            continue
        excerpts = excerpts_by_parent[parent_id]
        match_count = len(excerpts)
        child_excerpts = '\n'.join(excerpts[:3])
        enhanced_parents.append(Document(
            page_content=f"{parent.page_content}\n\nRelevant excerpts:\n{child_excerpts}",
            metadata={
                **parent.metadata,
                'child_relevance_score': match_count,
                'matching_children': match_count,
            },
        ))

    enhanced_parents.sort(
        key=lambda parent: parent.metadata.get('child_relevance_score', 0),
        reverse=True,
    )
    return enhanced_parents
# The process is treated as running on Hugging Face Spaces when either of
# these environment variables is set to a non-empty value.
is_hf_spaces = any(os.getenv(name) for name in ('SPACE_ID', 'SPACES_ZERO_GPU'))

if is_hf_spaces:
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
else:
    app.config['UPLOAD_FOLDER'] = 'uploads'

try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    # The preferred folder could not be created; retry under /tmp.
    print(f"Failed to create upload folder, falling back to /tmp: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# In-memory, per-process state keyed by session id.
session_data = {}
message_histories = {}

# Both models are loaded once at import time; a load failure is logged and
# re-raised so the module fails to import.
print('Loading embedding model...')
try:
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs=dict(device='cpu'),
        encode_kwargs=dict(normalize_embeddings=True),
    )
    print('Embedding model loaded.')
except Exception as e:
    print(f"FATAL: Could not load embedding model. Error: {e}")
    raise

print('Loading reranker model...')
try:
    RERANKER_MODEL = CrossEncoder(
        'cross-encoder/ms-marco-MiniLM-L-6-v2',
        device='cpu',
    )
    print('Reranker model loaded.')
except Exception as e:
    print(f"FATAL: Could not load reranker model. Error: {e}")
    raise
def load_pdf_with_fallback(filepath):
    """Load a PDF as one ``Document`` per page that contains text.

    Despite the name, only one extraction path (``fitz``) is attempted here.

    Args:
        filepath: Path of the PDF file to open.

    Returns:
        A list of documents whose metadata holds the file's base name under
        ``source`` and the 1-based page number under ``page``. Pages whose
        text is empty or whitespace-only are skipped.

    Raises:
        ValueError: If no page yields any text.
        Exception: Any error raised while opening or reading the file is
            logged and re-raised unchanged.
    """
    try:
        source_name = os.path.basename(filepath)
        docs = []
        with fitz.open(filepath) as pdf_doc:
            for page_number, page in enumerate(pdf_doc, start=1):
                text = page.get_text()
                if not text.strip():
                    continue
                docs.append(Document(
                    page_content=text,
                    metadata={'source': source_name, 'page': page_number},
                ))
        if not docs:
            raise ValueError('No text content found in PDF.')
        print(f"Loaded PDF: {source_name} - {len(docs)} pages")
        return docs
    except Exception as e:
        print(f"PyMuPDF failed for {filepath}: {e}")
        raise
# Maps a lower-cased file extension to the callable used to load it.
# The '.pdf' entry is a plain function that returns documents directly; the
# other entries are loader classes whose instances expose ``.load()``.
LOADER_MAPPING = {
    '.txt': TextLoader,
    '.pdf': load_pdf_with_fallback,
    '.docx': Docx2txtLoader,
}
def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use.

    Histories are kept in the module-level ``message_histories`` dict, so the
    same object is returned for repeated calls with the same id.
    """
    try:
        return message_histories[session_id]
    except KeyError:
        history = message_histories[session_id] = ChatMessageHistory()
        return history
def health_check():
    """Report that the service is up.

    Returns:
        A ``(response, status)`` pair: the JSON body ``{"status": "healthy"}``
        and HTTP status 200.
    """
    # NOTE(review): no route decorator is visible on this view in the source
    # as received -- confirm how it is registered with ``app``.
    payload = {'status': 'healthy'}
    return (jsonify(payload), 200)
def index():
    """Render and return the ``index.html`` template."""
    # NOTE(review): no route decorator is visible on this view in the source
    # as received -- confirm how it is registered with ``app``.
    page = render_template('index.html')
    return page
def upload_files():
    """Ingest uploaded documents and build a RAG pipeline for a new session.

    Reads from the current request:
        * files under the form field ``file``;
        * ``temperature`` (form field, default ``'0.2'``), parsed as a float;
        * ``model_name`` (form field, default ``'moonshotai/kimi-k2-instruct'``).

    Each supported file is saved to the upload folder, loaded through
    ``LOADER_MAPPING`` and chunked; the chunks feed a BM25 + FAISS ensemble
    retriever, reranked by ``LocalReranker`` and expanded to parent documents.
    The resulting chain is stored in ``session_data`` under a new UUID.

    Returns:
        * 200 with ``status``, ``filename`` (summary message), ``session_id``,
          ``model_name`` and ``mode`` on success;
        * 400 when the temperature is not a number, no file was selected, or
          no file could be processed;
        * 500 when building the pipeline fails.
    """
    # NOTE(review): no route decorator is visible on this view in the source
    # as received -- confirm how it is registered with ``app``.
    files = request.files.getlist('file')
    temperature_str = request.form.get('temperature', '0.2')
    try:
        temperature = float(temperature_str)
    except (TypeError, ValueError):
        # Client-supplied value: reject it instead of crashing with a 500.
        return (jsonify({'status': 'error',
                         'message': f'Invalid temperature value: {temperature_str!r}'}), 400)
    model_name = request.form.get('model_name',
                                  'moonshotai/kimi-k2-instruct')
    print(f"UPLOAD: Model: {model_name}, Temp: {temperature}")

    if not files or all(f.filename == '' for f in files):
        return (jsonify({'status': 'error',
                         'message': 'No selected files.'}), 400)

    (all_docs, processed_files, failed_files) = ([], [], [])
    print(f"Processing {len(files)} file(s)...")
    for file in files:
        if not (file and file.filename):
            continue
        filename = secure_filename(file.filename)
        # secure_filename() can sanitize a name down to ''; fall back to the
        # raw name for error reporting only.
        display_name = filename or file.filename
        try:
            if not filename:
                raise ValueError('Invalid file name.')
            # Validate the extension BEFORE touching the disk so unsupported
            # uploads are never written to the upload folder.
            file_ext = os.path.splitext(filename)[1].lower()
            if file_ext not in LOADER_MAPPING:
                raise ValueError('Unsupported file format.')
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            file.save(filepath)
            loader_func = LOADER_MAPPING[file_ext]
            # The PDF entry is a function returning documents; the others are
            # loader classes that need ``.load()``.
            if file_ext == '.pdf':
                docs = loader_func(filepath)
            else:
                docs = loader_func(filepath).load()
            if not docs:
                raise ValueError('No content extracted.')
            all_docs.extend(docs)
            processed_files.append(filename)
        except Exception as e:
            # Best effort: one bad file must not abort the whole upload.
            print(f"✗ Error processing {display_name}: {e}")
            failed_files.append(f"{display_name} ({e})")

    if not all_docs:
        return (jsonify({'status': 'error',
                         'message': f"Failed to process all files. Reasons: {', '.join(failed_files)}"
                         }), 400)
    print(f"UPLOAD: Processed {len(processed_files)} files.")

    try:
        print('Starting RAG pipeline setup...')
        (parent_docs, child_docs, doc_ids) = \
            create_optimized_parent_child_chunks(all_docs)
        if not child_docs:
            raise ValueError('No child documents created during chunking.')

        # Children are indexed for retrieval; parents are kept in a key-value
        # store so matches can be expanded back to their parent text.
        vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))
        print(f"Indexed {len(child_docs)} document chunks.")

        bm25_retriever = BM25Retriever.from_documents(child_docs)
        bm25_retriever.k = 12
        faiss_retriever = vectorstore.as_retriever(search_kwargs={'k': 12})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.6, 0.4])
        reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)

        def get_parents(docs: List[Document]) -> List[Document]:
            return get_context_aware_parents(docs, store)

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=reranker,
            base_retriever=ensemble_retriever)
        final_retriever = compression_retriever | get_parents

        session_id = str(uuid.uuid4())
        (rag_chain, api_key_manager) = create_rag_chain(
            retriever=final_retriever,
            get_session_history_func=get_session_history,
            model_name=model_name,
            temperature=temperature)
        session_data[session_id] = {'chain': rag_chain,
                                    'model_name': model_name,
                                    'temperature': temperature,
                                    'api_key_manager': api_key_manager}

        success_msg = f"Processed: {', '.join(processed_files)}"
        if failed_files:
            success_msg += f". Failed: {', '.join(failed_files)}"
        mode_label = TEMPERATURE_LABELS.get(temperature_str,
                                            temperature_str)
        print(f"UPLOAD COMPLETE: Session {session_id} is ready.")
        return jsonify({
            'status': 'success',
            'filename': success_msg,
            'session_id': session_id,
            'model_name': model_name,
            'mode': mode_label,
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return (jsonify({'status': 'error',
                         'message': f'RAG setup failed: {e}'}), 500)
def chat():
    """Answer a question for an existing session, streaming tokens as SSE.

    Request input:
        * GET: ``question`` and ``session_id`` query parameters.
        * POST: a JSON object with ``question`` and ``session_id``; when the
          body carries no ``session_id`` the Flask session's value is used.

    Returns:
        * On success, a ``text/event-stream`` response. Each frame is
          ``data: <json>`` where the JSON object has ``token``,
          ``model_name`` and ``mode``; a failure during generation is sent as
          a frame with a single ``error`` key.
        * 400 when the question is missing or the session is unknown, 500
          when the stream cannot be set up, 405 for other methods. For GET
          requests these errors are delivered as a single SSE ``error``
          frame; otherwise as a JSON body with ``status`` and ``message``.
    """
    # NOTE(review): no route decorator is visible on this view in the source
    # as received -- confirm how it is registered with ``app``.
    # Local import keeps this block self-contained, mirroring the
    # function-scope ``import traceback`` used elsewhere in this module.
    import json
    import traceback

    def sse_event(payload):
        """Serialize ``payload`` as one SSE data frame.

        json.dumps escapes quotes, backslashes, newlines and control
        characters, so the frame is always a single line of valid JSON.
        """
        return f'data: {json.dumps(payload)}\n\n'

    def error_response(message, status):
        """Build the error reply in the format the calling method expects."""
        if request.method == 'GET':
            def error_stream():
                yield sse_event({'error': message})
            return Response(stream_with_context(error_stream()),
                            mimetype='text/event-stream', status=status)
        return (jsonify({'status': 'error', 'message': message}), status)

    if request.method == 'GET':
        question = request.args.get('question')
        session_id = request.args.get('session_id')
    elif request.method == 'POST':
        # silent=True: a missing or malformed JSON body is handled by the
        # validation below instead of raising.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        question = data.get('question')
        session_id = data.get('session_id') or session.get('session_id')
    else:
        return (jsonify({'status': 'error', 'message': 'Method not allowed'}), 405)

    # Only log a preview once we know it is safe to slice: the question may
    # be absent or not a string at this point.
    preview = question[:50] if isinstance(question, str) else ''
    print(f"Received {request.method} request for chat: session={session_id}, question={preview}...")

    if not question or not isinstance(question, str):
        error_msg = "Error: No question provided."
        print(f"CHAT Validation Error: {error_msg}")
        return error_response(error_msg, 400)

    # The isinstance check also protects the dict lookup from unhashable
    # JSON values such as lists.
    if (not session_id or not isinstance(session_id, str)
            or session_id not in session_data):
        error_msg = "Error: Invalid session. Please upload documents first."
        print(f"CHAT Validation Error: Invalid session {session_id}.")
        return error_response(error_msg, 400)

    try:
        session_info = session_data[session_id]
        rag_chain = session_info['chain']
        model_name = session_info['model_name']
        temperature_float = session_info['temperature']
        temperature_str = str(temperature_float)
        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)
        print(f"CHAT: Streaming response for session {session_id} (Model: {model_name}, Temp: {temperature_float})...")

        def generate_chunks():
            try:
                stream_iterator = rag_chain.stream(
                    {'question': question},
                    config={'configurable': {'session_id': session_id}})
                for chunk in stream_iterator:
                    if isinstance(chunk, str):
                        yield sse_event({'token': chunk,
                                         'model_name': model_name,
                                         'mode': mode_label})
                    else:
                        print(f"Received non-string chunk: {type(chunk)}")
                print('CHAT: Streaming finished successfully.')
            except Exception as e:
                # Headers are already sent once streaming starts, so the
                # failure is reported in-band as an error frame.
                print(f"CHAT Error during streaming generation: {e}")
                traceback.print_exc()
                yield sse_event({'error': f"Error during response generation: {e}"})

        return Response(stream_with_context(generate_chunks()),
                        mimetype='text/event-stream')
    except Exception as e:
        print(f"CHAT Setup Error: {e}")
        traceback.print_exc()
        return error_response(f"Error setting up chat stream: {e}", 500)
def clean_markdown_for_tts(text: str) -> str:
    """Strip Markdown syntax so ``text`` can be read aloud.

    Applied in order:
        1. Links and images are replaced by their visible text / alt text
           (the URL is dropped).
        2. Backticks, asterisks, underscores and hashes are removed.
        3. Leading list bullets (``-``, ``*``, ``+``), numbered-list prefixes
           and block-quote markers are removed from each line.
        4. Newlines and runs of whitespace collapse to single spaces.

    Args:
        text: Markdown-formatted text.

    Returns:
        The cleaned text with surrounding whitespace stripped; may be empty.
    """
    # Keep the human-readable part of ``[label](url)`` / ``![alt](url)``:
    # deleting the whole construct would silently drop words from the speech.
    text = re.sub(r'!?\[(.*?)\]\(.*?\)', r'\1', text)
    text = re.sub(r'[`*_#]', '', text)
    text = re.sub(r'^\s*[\-\*\+]\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)
    text = re.sub(r'\n+', ' ', text)
    text = re.sub(r'\s{2,}', ' ', text)
    return text.strip()
def text_to_speech():
    """Convert the posted text to MP3 speech.

    Expects a JSON object with a string ``text`` field. Markdown is stripped
    with ``clean_markdown_for_tts`` before synthesis with ``gTTS`` (English).

    Returns:
        * 200 with the MP3 bytes (``audio/mpeg``) on success;
        * 400 when ``text`` is missing, not a string, or nothing speakable is
          left after cleaning;
        * 500 when audio generation fails.
    """
    # NOTE(review): no route decorator is visible on this view in the source
    # as received -- confirm how it is registered with ``app``.
    # silent=True plus the dict check: a missing, malformed or non-object
    # JSON body becomes a clean 400 instead of an AttributeError.
    data = request.get_json(silent=True)
    text = data.get('text') if isinstance(data, dict) else None
    if not text or not isinstance(text, str):
        return (jsonify({'status': 'error',
                         'message': 'No text provided.'}), 400)
    try:
        clean_text = clean_markdown_for_tts(text)
        if not clean_text:
            return (jsonify({'status': 'error', 'message': 'No speakable text found.'}), 400)
        tts = gTTS(clean_text, lang='en')
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        # Send the buffer as bytes: passing the BytesIO itself would be
        # iterated line by line when the response is streamed.
        return Response(mp3_fp.getvalue(), mimetype='audio/mpeg')
    except Exception as e:
        print(f"TTS Error: {e}")
        return (jsonify({'status': 'error',
                         'message': 'Failed to generate audio.'}), 500)
if __name__ == '__main__':
    # Listen on the port given by the PORT environment variable, or 7860
    # when it is not set.
    port = int(os.getenv('PORT', 7860))
    print(f"Starting Flask app on port {port}")
    app.run(debug=False, threaded=True, host='0.0.0.0', port=port)