# NOTE: Changed the model to mixtral-8x7b-32768 for groq (commit 7770bcb, Digambar29)
# Standard library
import json
import logging
import os
import pickle
import threading
import uuid

# Third-party
import faiss
import numpy as np
import pdfplumber
import requests
from dotenv import load_dotenv
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from gpt4all import GPT4All
from groq import Groq
from sentence_transformers import SentenceTransformer
# --- Basic Setup ---
load_dotenv() # Load environment variables from .env file
app = Flask(__name__, static_folder='../', static_url_path='/')
# Enable Cross-Origin Resource Sharing for specific origins.
# This allows your frontend (e.g., from http://127.0.0.1:5500) to talk to the backend.
CORS(app, resources={r"/api/*": {"origins": "*"}})
logging.basicConfig(level=logging.INFO)
# --- Configuration & Global State ---
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
JOBS_DIR = os.path.join(os.path.dirname(__file__), 'jobs')
SESSIONS_DIR = os.path.join(os.path.dirname(__file__), 'sessions')
os.makedirs(SESSIONS_DIR, exist_ok=True)
os.makedirs(JOBS_DIR, exist_ok=True)
embedding_model = None
local_llm = None
api_llm = None
# Create a centralized requests session to reuse connections and headers
http_session = requests.Session()
http_session.headers.update({
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
})
# Define HF_CACHE_DIR using the environment variable, falling back to a default if not set
# In a Docker environment, HF_HOME will be set by the Dockerfile
HF_CACHE_DIR = os.getenv("HF_HOME", os.path.join("/app", ".cache", "huggingface"))
def get_embedding_model():
"""Lazy-loads the embedding model to speed up initial server start."""
global embedding_model
if embedding_model is None:
logging.info("Loading embedding model for the first time...")
try:
# Explicitly pass the cache_folder to ensure it uses the writable directory
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=HF_CACHE_DIR)
logging.info(f"Embedding model loaded. Cache folder: {HF_CACHE_DIR}")
except Exception as e:
logging.error(f"CRITICAL ERROR: Failed to load embedding model from {HF_CACHE_DIR}. "
f"Please check permissions and ensure the model can be downloaded/accessed. Error: {e}")
raise # Re-raise the exception so the calling function (e.g., _run_training_job) can handle it
return embedding_model
def get_local_llm():
"""Lazy-loads the local GPT4All model, downloading if necessary."""
global local_llm
if local_llm is None:
logging.info("Initializing local LLM for fallback...")
MODELS_DIR = os.path.join(os.path.dirname(__file__), 'models')
MODEL_NAME = "orca-mini-3b-gguf2-q4_0.gguf"
GPT_MODEL_PATH = os.path.join(MODELS_DIR, MODEL_NAME)
# In a containerized environment, the model is pre-downloaded during the build.
# This check is a safeguard for local development.
if not os.path.exists(GPT_MODEL_PATH):
error_message = f"Local model not found at {GPT_MODEL_PATH}. It should have been included in the Docker build."
logging.error(error_message)
# In a production deployment, we should not attempt to download it at runtime.
raise FileNotFoundError(error_message)
logging.info(f"Loading GPT4All model from: {GPT_MODEL_PATH}")
local_llm = GPT4All(GPT_MODEL_PATH)
logging.info("Local LLM loaded.")
return local_llm
# --- Helper Functions ---
def chunk_text(text, chunk_size=500, overlap=50):
words = text.split()
return [" ".join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size - overlap)]
# --- API Endpoints ---
@app.route('/')
def serve_index():
return send_from_directory(app.static_folder, 'index.html')
@app.route('/api/search', methods=['POST'])
def search_nasa_api():
"""
Searches the NASA NTRS API for papers.
"""
data = request.get_json()
keyword = data.get('keyword')
if not keyword:
return jsonify({"error": "Keyword is required"}), 400
logging.info(f"Searching NTRS for: {keyword}")
# Using the public NTRS API endpoint
search_url = f"https://ntrs.nasa.gov/api/citations/search?q={keyword}&limit=15"
try:
response = http_session.get(search_url)
response.raise_for_status()
results = response.json().get('results', [])
# Filter and format results to include only what we need
formatted_results = []
for item in results:
# Safely find the download link for the PDF, checking that 'mimeType' exists.
# Correctly parse the nested link structure based on the actual API response.
pdf_link = next((d.get('links', {}).get('pdf') for d in item.get('downloads', []) if d.get('mimetype') == 'application/pdf'), None)
if pdf_link:
formatted_results.append({
'id': item.get('id'),
'title': item.get('title'),
'abstract': item.get('abstract'),
'pdfLink': f"https://ntrs.nasa.gov{pdf_link}"
})
return jsonify(formatted_results)
except requests.RequestException as e:
logging.error(f"NTRS API request failed: {e}")
return jsonify({"error": "Failed to fetch data from NASA NTRS API"}), 500
def _run_training_job(job_id, papers):
"""The actual long-running training logic, using GCS if configured."""
job_file = os.path.join(JOBS_DIR, f"{job_id}.json")
def update_status(status, message, extra_data=None):
"""Updates job status in a file (local or GCS)."""
job_data = {"status": status, "message": message, **(extra_data or {})}
with open(job_file, 'w') as f: json.dump(job_data, f)
def save_session_data(session_id, index, metadata):
"""Saves FAISS index and metadata to GCS or local."""
session_path = os.path.join(SESSIONS_DIR, session_id)
os.makedirs(session_path, exist_ok=True)
faiss.write_index(index, os.path.join(session_path, "index.faiss"))
with open(os.path.join(session_path, "metadata.pkl"), "wb") as f:
pickle.dump(metadata, f)
try:
update_status("processing", f"Training on {len(papers)} papers...")
logging.info(f"Starting background training for job {job_id} on {len(papers)} papers.")
if not papers:
raise ValueError("List of papers is required for training.")
# Use a temporary directory for PDF downloads
papers_folder = "papers"
os.makedirs(papers_folder, exist_ok=True)
all_chunks = []
chunk_metadata = []
for i, paper in enumerate(papers):
url = paper.get('url') # Correctly get the 'url' key sent from the frontend
paper_id = paper.get('id')
paper_title = paper.get('title', 'Unknown Title')
# Update progress
update_status("processing", f"Processing paper {i+1}/{len(papers)}: {paper_title[:30]}...")
try:
filename = f"{paper_id}.pdf" # Use unique ID for filename
filepath = os.path.join(papers_folder, filename)
# Download the paper if it doesn't exist
if not os.path.exists(filepath):
logging.info(f"Downloading {url}...")
response = http_session.get(url, stream=True)
response.raise_for_status()
with open(filepath, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
# Extract text
with pdfplumber.open(filepath) as pdf:
text = "".join(page.extract_text() + "\n" for page in pdf.pages if page.extract_text())
# Chunk the text and associate metadata immediately to save memory
paper_chunks = chunk_text(text)
all_chunks.extend(paper_chunks)
chunk_metadata.extend([{'id': paper_id, 'title': paper_title}] * len(paper_chunks))
except Exception as e:
logging.error(f"Failed to process paper {url}: {e}")
continue # Skip failed papers
if not all_chunks:
raise ValueError("Could not process any of the selected papers.")
# Chunk, embed, and create FAISS index
update_status("processing", "Embedding text... This is the slowest step.")
# Use multi-process encoding for a significant speed-up on multi-core CPUs
embeddings = get_embedding_model().encode(all_chunks, show_progress_bar=False, batch_size=32)
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(np.array(embeddings))
# Generate a unique session ID and store the state
session_id = str(uuid.uuid4())
save_session_data(session_id, index, {"chunks": all_chunks, "chunk_metadata": chunk_metadata})
logging.info(f"Training complete for job {job_id}. Session saved to {session_id}.")
# Mark job as complete
update_status("complete", f"Chatbot is ready! Trained on {len(papers)} papers.", {"sessionId": session_id})
except Exception as e:
logging.error(f"Background job {job_id} failed: {e}")
update_status("error", str(e))
@app.route('/api/train', methods=['POST'])
def train_chatbot():
"""
Kicks off the training process in a background thread.
"""
data = request.get_json()
papers = data.get('papers')
if not papers:
return jsonify({"error": "List of papers is required"}), 400
job_id = str(uuid.uuid4())
initial_status = {"status": "starting", "message": "Initializing training job..."}
job_file = os.path.join(JOBS_DIR, f"{job_id}.json")
with open(job_file, 'w') as f:
json.dump(initial_status, f)
# Run the long task in a background thread
import threading
thread = threading.Thread(target=_run_training_job, args=(job_id, papers))
thread.start()
# Immediately return a response to the client
return jsonify({"message": "Training started.", "jobId": job_id}), 202
@app.route('/api/train/status/<job_id>', methods=['GET'])
def get_training_status(job_id):
"""
Checks the status of a background training job.
"""
job_file = os.path.join(JOBS_DIR, f"{job_id}.json")
if not os.path.exists(job_file):
return jsonify({"error": "Job not found"}), 404
with open(job_file, 'r') as f:
job_data = json.load(f)
return jsonify(job_data)
@app.route('/api/ask', methods=['POST'])
def ask_question_endpoint():
"""
Answers a question using the currently loaded FAISS index.
"""
data = request.get_json()
query = data.get('query')
session_id = data.get('sessionId')
if not all([query, session_id]):
return jsonify({"error": "Query and sessionId are required"}), 400
try:
# --- 1. Load Session Data ---
session_path = os.path.join(SESSIONS_DIR, session_id)
index_path = os.path.join(session_path, "index.faiss")
metadata_path = os.path.join(session_path, "metadata.pkl")
if not os.path.exists(index_path):
return jsonify({"error": "Invalid or expired session. Please train the chatbot again."}), 404
index = faiss.read_index(index_path)
with open(metadata_path, "rb") as f: metadata = pickle.load(f)
chunks, chunk_metadata = metadata["chunks"], metadata["chunk_metadata"]
# --- 4. Generate Response (with API/Local Fallback) ---
response = ""
use_api = GROQ_API_KEY and GROQ_API_KEY != "your_groq_api_key_here"
if use_api:
# --- Primary Method: Try Groq API ---
global api_llm
if api_llm is None: api_llm = Groq(api_key=GROQ_API_KEY)
# Retrieve top 3 chunks for the powerful API model
query_vec = get_embedding_model().encode([query], show_progress_bar=False)
distances, indices = index.search(np.array(query_vec), k=3)
context = "\n".join([chunks[i] for i in indices[0]])
sources = list(set([chunk_metadata[i]['title'] for i in indices[0] if i < len(chunk_metadata)]))
prompt_template = ("Using ONLY the information from the following context, answer the question. "
"Do not mention the context in your answer. Be concise.\n\n"
"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
prompt = prompt_template.format(context=context, query=query)
logging.info("Attempting to generate response with Groq API...")
try:
chat_completion = api_llm.chat.completions.create(
messages=[{"role": "user", "content": prompt}],
# Use a stable, powerful model like Mixtral.
# Other options include 'gemma-7b-it'.
model="mixtral-8x7b-32768",
temperature=0.5, max_tokens=250
)
response = chat_completion.choices[0].message.content
logging.info("Groq API call successful.")
except Exception as api_e:
logging.warning(f"API call failed: {api_e}. Will attempt to fall back to local model.")
use_api = False # Force fallback
if not use_api:
# --- Fallback Method: Use Local Model ---
# The local model has a small context window, so we only use the single most relevant chunk (k=1).
query_vec = get_embedding_model().encode([query], show_progress_bar=False)
distances, indices = index.search(np.array(query_vec), k=1)
context = "\n".join([chunks[i] for i in indices[0]])
sources = list(set([chunk_metadata[i]['title'] for i in indices[0] if i < len(chunk_metadata)]))
prompt_template = ("Using ONLY the information from the following context, answer the question. "
"Do not mention the context in your answer. Be concise.\n\n"
"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
prompt = prompt_template.format(context=context, query=query)
logging.info("Generating response with local model.")
global local_llm
if local_llm is None: local_llm = get_local_llm()
response = local_llm.generate(prompt, max_tokens=250, temp=0.5)
# Check if the local model returned an error message instead of an answer
if "LLaMA ERROR" in response:
raise RuntimeError("The local model failed to generate a response due to context size limitations.")
return jsonify({"answer": response, "sources": sources})
except Exception as e:
logging.error(f"An unexpected error occurred in /api/ask for session {session_id}: {e}", exc_info=True)
return jsonify({"error": "An internal server error occurred. The model may have failed to load or process the request. Check server logs for details."}), 500
if __name__ == '__main__':
app.run(debug=True, port=5001)