import json
import logging
import os
import pickle
import threading
import uuid

import faiss
import numpy as np
import pdfplumber
import requests
from dotenv import load_dotenv
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from gpt4all import GPT4All
from groq import Groq
from sentence_transformers import SentenceTransformer
|
|
| |
| load_dotenv() |
| app = Flask(__name__, static_folder='../', static_url_path='/') |
| |
| |
| CORS(app, resources={r"/api/*": {"origins": "*"}}) |
|
|
| logging.basicConfig(level=logging.INFO) |
|
|
| |
| GROQ_API_KEY = os.getenv("GROQ_API_KEY") |
| JOBS_DIR = os.path.join(os.path.dirname(__file__), 'jobs') |
| SESSIONS_DIR = os.path.join(os.path.dirname(__file__), 'sessions') |
| os.makedirs(SESSIONS_DIR, exist_ok=True) |
| os.makedirs(JOBS_DIR, exist_ok=True) |
|
|
| embedding_model = None |
| local_llm = None |
| api_llm = None |
|
|
| |
| http_session = requests.Session() |
| http_session.headers.update({ |
| 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' |
| }) |
|
|
| |
| |
| HF_CACHE_DIR = os.getenv("HF_HOME", os.path.join("/app", ".cache", "huggingface")) |
|
|
| def get_embedding_model(): |
| """Lazy-loads the embedding model to speed up initial server start.""" |
| global embedding_model |
| if embedding_model is None: |
| logging.info("Loading embedding model for the first time...") |
| try: |
| |
| embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=HF_CACHE_DIR) |
| logging.info(f"Embedding model loaded. Cache folder: {HF_CACHE_DIR}") |
| except Exception as e: |
| logging.error(f"CRITICAL ERROR: Failed to load embedding model from {HF_CACHE_DIR}. " |
| f"Please check permissions and ensure the model can be downloaded/accessed. Error: {e}") |
| raise |
| return embedding_model |
|
|
| def get_local_llm(): |
| """Lazy-loads the local GPT4All model, downloading if necessary.""" |
| global local_llm |
| if local_llm is None: |
| logging.info("Initializing local LLM for fallback...") |
| MODELS_DIR = os.path.join(os.path.dirname(__file__), 'models') |
| MODEL_NAME = "orca-mini-3b-gguf2-q4_0.gguf" |
| GPT_MODEL_PATH = os.path.join(MODELS_DIR, MODEL_NAME) |
|
|
| |
| |
| if not os.path.exists(GPT_MODEL_PATH): |
| error_message = f"Local model not found at {GPT_MODEL_PATH}. It should have been included in the Docker build." |
| logging.error(error_message) |
| |
| raise FileNotFoundError(error_message) |
| |
| logging.info(f"Loading GPT4All model from: {GPT_MODEL_PATH}") |
| local_llm = GPT4All(GPT_MODEL_PATH) |
| logging.info("Local LLM loaded.") |
| return local_llm |
|
|
| |
| def chunk_text(text, chunk_size=500, overlap=50): |
| words = text.split() |
| return [" ".join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size - overlap)] |
|
|
| |
|
|
| @app.route('/') |
| def serve_index(): |
| return send_from_directory(app.static_folder, 'index.html') |
|
|
| @app.route('/api/search', methods=['POST']) |
| def search_nasa_api(): |
| """ |
| Searches the NASA NTRS API for papers. |
| """ |
| data = request.get_json() |
| keyword = data.get('keyword') |
| if not keyword: |
| return jsonify({"error": "Keyword is required"}), 400 |
|
|
| logging.info(f"Searching NTRS for: {keyword}") |
| |
| search_url = f"https://ntrs.nasa.gov/api/citations/search?q={keyword}&limit=15" |
| try: |
| response = http_session.get(search_url) |
| response.raise_for_status() |
| results = response.json().get('results', []) |
| |
| |
| formatted_results = [] |
| for item in results: |
| |
| |
| pdf_link = next((d.get('links', {}).get('pdf') for d in item.get('downloads', []) if d.get('mimetype') == 'application/pdf'), None) |
| if pdf_link: |
| formatted_results.append({ |
| 'id': item.get('id'), |
| 'title': item.get('title'), |
| 'abstract': item.get('abstract'), |
| 'pdfLink': f"https://ntrs.nasa.gov{pdf_link}" |
| }) |
|
|
| return jsonify(formatted_results) |
| except requests.RequestException as e: |
| logging.error(f"NTRS API request failed: {e}") |
| return jsonify({"error": "Failed to fetch data from NASA NTRS API"}), 500 |
|
|
| def _run_training_job(job_id, papers): |
| """The actual long-running training logic, using GCS if configured.""" |
| job_file = os.path.join(JOBS_DIR, f"{job_id}.json") |
|
|
| def update_status(status, message, extra_data=None): |
| """Updates job status in a file (local or GCS).""" |
| job_data = {"status": status, "message": message, **(extra_data or {})} |
| with open(job_file, 'w') as f: json.dump(job_data, f) |
|
|
| def save_session_data(session_id, index, metadata): |
| """Saves FAISS index and metadata to GCS or local.""" |
| session_path = os.path.join(SESSIONS_DIR, session_id) |
| os.makedirs(session_path, exist_ok=True) |
| faiss.write_index(index, os.path.join(session_path, "index.faiss")) |
| with open(os.path.join(session_path, "metadata.pkl"), "wb") as f: |
| pickle.dump(metadata, f) |
|
|
| try: |
| update_status("processing", f"Training on {len(papers)} papers...") |
| logging.info(f"Starting background training for job {job_id} on {len(papers)} papers.") |
| |
| if not papers: |
| raise ValueError("List of papers is required for training.") |
|
|
| |
| papers_folder = "papers" |
| os.makedirs(papers_folder, exist_ok=True) |
|
|
| all_chunks = [] |
| chunk_metadata = [] |
| for i, paper in enumerate(papers): |
| url = paper.get('url') |
| paper_id = paper.get('id') |
| paper_title = paper.get('title', 'Unknown Title') |
| |
| update_status("processing", f"Processing paper {i+1}/{len(papers)}: {paper_title[:30]}...") |
| try: |
| filename = f"{paper_id}.pdf" |
| filepath = os.path.join(papers_folder, filename) |
|
|
| |
| if not os.path.exists(filepath): |
| logging.info(f"Downloading {url}...") |
| response = http_session.get(url, stream=True) |
| response.raise_for_status() |
| with open(filepath, 'wb') as f: |
| for chunk in response.iter_content(chunk_size=8192): |
| f.write(chunk) |
| |
| |
| with pdfplumber.open(filepath) as pdf: |
| text = "".join(page.extract_text() + "\n" for page in pdf.pages if page.extract_text()) |
| |
| |
| paper_chunks = chunk_text(text) |
| all_chunks.extend(paper_chunks) |
| chunk_metadata.extend([{'id': paper_id, 'title': paper_title}] * len(paper_chunks)) |
|
|
| except Exception as e: |
| logging.error(f"Failed to process paper {url}: {e}") |
| continue |
|
|
| if not all_chunks: |
| raise ValueError("Could not process any of the selected papers.") |
|
|
| |
| update_status("processing", "Embedding text... This is the slowest step.") |
| |
| |
| embeddings = get_embedding_model().encode(all_chunks, show_progress_bar=False, batch_size=32) |
| |
| dimension = embeddings.shape[1] |
| index = faiss.IndexFlatL2(dimension) |
| index.add(np.array(embeddings)) |
|
|
| |
| session_id = str(uuid.uuid4()) |
| save_session_data(session_id, index, {"chunks": all_chunks, "chunk_metadata": chunk_metadata}) |
| logging.info(f"Training complete for job {job_id}. Session saved to {session_id}.") |
| |
| update_status("complete", f"Chatbot is ready! Trained on {len(papers)} papers.", {"sessionId": session_id}) |
|
|
| except Exception as e: |
| logging.error(f"Background job {job_id} failed: {e}") |
| update_status("error", str(e)) |
|
|
| @app.route('/api/train', methods=['POST']) |
| def train_chatbot(): |
| """ |
| Kicks off the training process in a background thread. |
| """ |
| data = request.get_json() |
| papers = data.get('papers') |
| if not papers: |
| return jsonify({"error": "List of papers is required"}), 400 |
|
|
| job_id = str(uuid.uuid4()) |
| initial_status = {"status": "starting", "message": "Initializing training job..."} |
| job_file = os.path.join(JOBS_DIR, f"{job_id}.json") |
| with open(job_file, 'w') as f: |
| json.dump(initial_status, f) |
|
|
| |
| import threading |
| thread = threading.Thread(target=_run_training_job, args=(job_id, papers)) |
| thread.start() |
|
|
| |
| return jsonify({"message": "Training started.", "jobId": job_id}), 202 |
|
|
| @app.route('/api/train/status/<job_id>', methods=['GET']) |
| def get_training_status(job_id): |
| """ |
| Checks the status of a background training job. |
| """ |
| job_file = os.path.join(JOBS_DIR, f"{job_id}.json") |
| if not os.path.exists(job_file): |
| return jsonify({"error": "Job not found"}), 404 |
| |
| with open(job_file, 'r') as f: |
| job_data = json.load(f) |
| return jsonify(job_data) |
|
|
| @app.route('/api/ask', methods=['POST']) |
| def ask_question_endpoint(): |
| """ |
| Answers a question using the currently loaded FAISS index. |
| """ |
| data = request.get_json() |
| query = data.get('query') |
| session_id = data.get('sessionId') |
|
|
| if not all([query, session_id]): |
| return jsonify({"error": "Query and sessionId are required"}), 400 |
|
|
| try: |
| |
| session_path = os.path.join(SESSIONS_DIR, session_id) |
| index_path = os.path.join(session_path, "index.faiss") |
| metadata_path = os.path.join(session_path, "metadata.pkl") |
| if not os.path.exists(index_path): |
| return jsonify({"error": "Invalid or expired session. Please train the chatbot again."}), 404 |
| |
| index = faiss.read_index(index_path) |
| with open(metadata_path, "rb") as f: metadata = pickle.load(f) |
| chunks, chunk_metadata = metadata["chunks"], metadata["chunk_metadata"] |
|
|
| |
| response = "" |
| use_api = GROQ_API_KEY and GROQ_API_KEY != "your_groq_api_key_here" |
|
|
| if use_api: |
| |
| global api_llm |
| if api_llm is None: api_llm = Groq(api_key=GROQ_API_KEY) |
|
|
| |
| query_vec = get_embedding_model().encode([query], show_progress_bar=False) |
| distances, indices = index.search(np.array(query_vec), k=3) |
| context = "\n".join([chunks[i] for i in indices[0]]) |
| sources = list(set([chunk_metadata[i]['title'] for i in indices[0] if i < len(chunk_metadata)])) |
| |
| prompt_template = ("Using ONLY the information from the following context, answer the question. " |
| "Do not mention the context in your answer. Be concise.\n\n" |
| "Context:\n{context}\n\nQuestion: {query}\n\nAnswer:") |
| prompt = prompt_template.format(context=context, query=query) |
|
|
| logging.info("Attempting to generate response with Groq API...") |
| try: |
| chat_completion = api_llm.chat.completions.create( |
| messages=[{"role": "user", "content": prompt}], |
| |
| |
| model="mixtral-8x7b-32768", |
| temperature=0.5, max_tokens=250 |
| ) |
| response = chat_completion.choices[0].message.content |
| logging.info("Groq API call successful.") |
| except Exception as api_e: |
| logging.warning(f"API call failed: {api_e}. Will attempt to fall back to local model.") |
| use_api = False |
|
|
| if not use_api: |
| |
| |
| query_vec = get_embedding_model().encode([query], show_progress_bar=False) |
| distances, indices = index.search(np.array(query_vec), k=1) |
| context = "\n".join([chunks[i] for i in indices[0]]) |
| sources = list(set([chunk_metadata[i]['title'] for i in indices[0] if i < len(chunk_metadata)])) |
|
|
| prompt_template = ("Using ONLY the information from the following context, answer the question. " |
| "Do not mention the context in your answer. Be concise.\n\n" |
| "Context:\n{context}\n\nQuestion: {query}\n\nAnswer:") |
| prompt = prompt_template.format(context=context, query=query) |
|
|
| logging.info("Generating response with local model.") |
| global local_llm |
| if local_llm is None: local_llm = get_local_llm() |
| response = local_llm.generate(prompt, max_tokens=250, temp=0.5) |
| |
| if "LLaMA ERROR" in response: |
| raise RuntimeError("The local model failed to generate a response due to context size limitations.") |
|
|
| return jsonify({"answer": response, "sources": sources}) |
|
|
| except Exception as e: |
| logging.error(f"An unexpected error occurred in /api/ask for session {session_id}: {e}", exc_info=True) |
| return jsonify({"error": "An internal server error occurred. The model may have failed to load or process the request. Check server logs for details."}), 500 |
|
|
| if __name__ == '__main__': |
| app.run(debug=True, port=5001) |
|
|