# retriever / app.py
# Author: ntdservices
# Commit d555613 "Update app.py" (verified)
# app.py
import bisect
import os
import re
import shutil
import uuid

import faiss
import numpy as np
from flask import Flask, request, render_template, send_file, redirect, url_for, jsonify
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
# Optional NLI pipeline (not required for this feature). Nothing in this file
# (as reviewed) calls `nli`; it is loaded best-effort and left as None when
# transformers or the model cannot be loaded. Set ENABLE_NLI=0 to skip loading
# the large model entirely; the default keeps the previous always-try behavior.
nli = None
if os.environ.get("ENABLE_NLI", "1") != "0":
    try:
        from transformers import pipeline as hf_pipeline
        nli = hf_pipeline("text-classification", model="microsoft/deberta-large-mnli")
        print("✅ NLI pipeline loaded.")
    except Exception as e:  # best-effort: any import/download/load failure just disables NLI
        nli = None
        print("ℹ️ NLI pipeline not loaded (optional):", e)
else:
    print("ℹ️ NLI pipeline disabled (ENABLE_NLI=0).")
print("⏳ Loading SentenceTransformer model...")
# Sentence encoder shared by index building (rebuild_merged_and_index) and
# query embedding (the "/" route).
model = SentenceTransformer('all-MiniLM-L6-v2')
print("✅ Encoder loaded.")
app = Flask(__name__)
# ── base folders ───────────────────────────────────────────────────────────────
# Resolve against this file's absolute location so the uploads/results folders
# do not depend on the current working directory (dirname of a relative
# __file__ would be "").
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_UPLOADS = os.path.join(BASE_DIR, "uploads")
BASE_RESULTS = os.path.join(BASE_DIR, "results")
os.makedirs(BASE_UPLOADS, exist_ok=True)
os.makedirs(BASE_RESULTS, exist_ok=True)
# ── clear uploads at launch ────────────────────────────────────────────────────
def clear_uploads_folder():
    """Remove every file and sub-folder inside BASE_UPLOADS (run once at app launch).

    Symbolic links are unlinked rather than followed. A link inside uploads that
    points elsewhere therefore never causes files outside the uploads folder to
    be deleted.
    """
    for entry in os.listdir(BASE_UPLOADS):
        path = os.path.join(BASE_UPLOADS, entry)
        # os.path.isdir follows symlinks, so test islink first to avoid
        # recursing into (and emptying) a link target outside uploads.
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            os.remove(path)
clear_uploads_folder()
print("✅ Uploads folder cleared.")
# Runtime cache keyed by search id (sid). Each value is the 6-tuple
#   (paragraphs_norm, embeddings, faiss_index, spans, para_file_idx, file_meta)
# spans[i]         = (start_char, end_char) of paragraph i within merged.txt
# para_file_idx[i] = index into file_meta for paragraph i
# file_meta[j]     = {"name": filename, "start": start_char, "end": end_char, "second_line": str}
# A sid whose files yield no substantive paragraphs maps to ([], None, None, [], [], []).
# NOTE(review): entries are only dropped by /reset, so the cache grows with every
# sid that uploads; it may also be touched concurrently if the server is threaded.
index_data = {}
# ── helpers ────────────────────────────────────────────────────────────────────
# Search ids arrive from the query string / form, i.e. from the client, and are
# joined onto BASE_UPLOADS / BASE_RESULTS. Only plain tokens are accepted: no
# path separators, no "..", no absolute paths. uuid4 strings (as minted by "/")
# match this pattern.
_SID_PATTERN = re.compile(r"[A-Za-z0-9_-]{1,128}")


def get_paths(sid: str):
    """Return ``(upload_folder, results_folder, merged_file, result_file)`` for *sid*.

    Both per-sid folders are created if they do not exist yet.

    Raises:
        ValueError: if *sid* is not a 1-128 character token of ASCII letters,
            digits, '_' or '-'. Without this check a sid such as "../x" or
            "/etc" would make os.path.join escape the base folders.
    """
    if not isinstance(sid, str) or _SID_PATTERN.fullmatch(sid) is None:
        raise ValueError(f"Invalid sid: {sid!r}")
    up_folder = os.path.join(BASE_UPLOADS, sid)
    res_folder = os.path.join(BASE_RESULTS, sid)
    os.makedirs(up_folder, exist_ok=True)
    os.makedirs(res_folder, exist_ok=True)
    merged_file = os.path.join(res_folder, "merged.txt")
    result_file = os.path.join(res_folder, "results.txt")
    return up_folder, res_folder, merged_file, result_file
def compute_second_line(raw_text: str) -> str:
    """Pick a short "subtitle" line out of *raw_text*.

    Preference order: the second line that has non-whitespace content; failing
    that, the second physical line (stripped); failing that, the empty string.
    """
    physical = [line.strip() for line in raw_text.splitlines()]
    with_content = [line for line in physical if line]
    if len(with_content) > 1:
        return with_content[1]
    return physical[1] if len(physical) > 1 else ""
def extract_for_merge_and_second(file_path: str):
    """Read one uploaded file and return ``(text_for_merge, second_line)``.

    * ``.txt``: decoded as UTF-8 with undecodable bytes dropped. The raw text is
      both merged as-is and used to find the second line.
    * ``.pdf``: page texts from PdfReader are joined with newlines to find the
      second line, and joined with spaces (then lightly re-paragraphed) for the
      merged view.
    * any other extension: ``("", "")``.

    Extension matching is case-insensitive.
    """
    lowered = file_path.lower()

    if lowered.endswith(".txt"):
        with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
            text = handle.read()
        return text, compute_second_line(text)

    if not lowered.endswith(".pdf"):
        return "", ""

    page_texts = [pg.extract_text() or "" for pg in PdfReader(file_path).pages]
    # The newline join keeps line structure so the 2nd-line heuristic has lines to pick from.
    second = compute_second_line("\n".join(page_texts))
    merged = " ".join(page_texts)
    # Two or more whitespace chars after sentence-ending punctuation become a paragraph break...
    merged = re.sub(r'(?<=[.!?])\s{2,}', '\n\n', merged)
    # ...and a run of 2+ newlines (plus the whitespace following them) collapses to "\n\n".
    merged = re.sub(r'(\n\s*){2,}', '\n\n', merged)
    return merged, second
# A run of one or more line breaks ("\n" or "\r\n") ends a segment.
_LINE_BREAK_RUN = re.compile(r'(?:\r?\n)+')
_WHITESPACE_RUN = re.compile(r'\s+')


def split_paragraphs_with_spans(merged_text: str, min_words: int = 5):
    """Split *merged_text* into candidate paragraphs with their exact character spans.

    The text is cut at every run of line breaks. Note that this means every
    line, not only blank-line-separated blocks, is a candidate paragraph. Each
    segment is whitespace-normalized for embedding. Segments with fewer than
    *min_words* words are dropped (the default of 5 keeps only substantive chunks).

    Returns:
        (paras_norm, spans): paras_norm[i] is the normalized text of paragraph i,
        and spans[i] = (start, end) are the offsets of the raw, un-normalized
        segment in *merged_text*, for highlighting and jumping.
    """
    breaks = list(_LINE_BREAK_RUN.finditer(merged_text))
    # Segment i runs from the end of break i-1 (or 0) to the start of break i (or EOF).
    seg_starts = [0] + [m.end() for m in breaks]
    seg_ends = [m.start() for m in breaks] + [len(merged_text)]
    paras_norm = []
    spans = []
    for start, end in zip(seg_starts, seg_ends):
        norm = _WHITESPACE_RUN.sub(' ', merged_text[start:end]).strip()
        if len(norm.split()) >= min_words:
            paras_norm.append(norm)
            spans.append((start, end))
    return paras_norm, spans
def _iter_upload_texts(up_folder: str):
    """Yield (filename, text, second_line) for each non-empty .pdf/.txt file in *up_folder*, sorted by name."""
    for filename in sorted(os.listdir(up_folder)):
        if not filename.lower().endswith((".pdf", ".txt")):
            continue
        file_path = os.path.join(up_folder, filename)
        if not os.path.isfile(file_path):
            continue  # e.g. a directory that happens to be named "x.pdf"
        part, second = extract_for_merge_and_second(file_path)
        # Normalize line endings *before* offsets are computed. merged.txt is
        # read back in text mode (universal newlines), which turns "\r\n" and
        # "\r" into "\n"; a stray "\r" kept here would shift every later span.
        part = part.replace("\r\n", "\n").replace("\r", "\n").rstrip()
        if part:
            yield filename, part, second


def _concat_with_offsets(entries):
    """Join file texts with a blank line after each one; return (merged_text, file_meta)."""
    pieces = []
    file_meta = []
    offset = 0
    for filename, part, second in entries:
        # [start, end) covers the file's own text, excluding the "\n\n" separator.
        file_meta.append({"name": filename, "start": offset, "end": offset + len(part), "second_line": second})
        pieces.append(part + "\n\n")
        offset += len(part) + 2
    return "".join(pieces), file_meta


def _paragraph_file_indices(spans, file_meta):
    """Map each paragraph span to the index of the file whose text contains its start offset."""
    starts = [meta["start"] for meta in file_meta]  # strictly increasing; starts[0] == 0
    return [max(0, bisect.bisect_right(starts, pstart) - 1) for pstart, _pend in spans]


def _embed_and_index(paragraphs):
    """Encode *paragraphs*, L2-normalize them and load them into an inner-product FAISS index."""
    embed = model.encode(paragraphs, batch_size=32, show_progress_bar=False)
    # FAISS only accepts C-contiguous float32 matrices.
    embed = np.ascontiguousarray(embed, dtype=np.float32)
    if embed.ndim == 1:
        embed = embed[np.newaxis, :]
    faiss.normalize_L2(embed)
    idx = faiss.IndexFlatIP(embed.shape[1])
    idx.add(embed)
    return embed, idx


def rebuild_merged_and_index(sid: str):
    """Rebuild merged.txt, paragraph embeddings, spans and per-paragraph file mapping for *sid*.

    Files in the sid's upload folder are concatenated in sorted-name order, each
    followed by a blank line, and written to merged.txt. The text is then split
    into paragraphs, and the result is cached in ``index_data[sid]`` as
    (paragraphs_norm, embeddings, faiss_index, spans, para_file_idx, file_meta).
    When no substantive paragraph exists, ([], None, None, [], [], []) is stored.
    """
    up_folder, _res_folder, merged_file, _ = get_paths(sid)
    merged_text, file_meta = _concat_with_offsets(_iter_upload_texts(up_folder))
    with open(merged_file, "w", encoding="utf-8") as f:
        f.write(merged_text)

    paras_norm, spans = split_paragraphs_with_spans(merged_text)
    if not paras_norm:
        index_data[sid] = ([], None, None, [], [], [])
        return

    para_file_idx = _paragraph_file_indices(spans, file_meta)
    embed, idx = _embed_and_index(paras_norm)
    index_data[sid] = (paras_norm, embed, idx, spans, para_file_idx, file_meta)
# ── routes ─────────────────────────────────────────────────────────────────────
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the search page; on POST, run a top-k semantic search for this sid.

    The search id comes from the query string or form; a new UUID4 is minted
    when it is absent. On POST with a non-empty query and an existing index,
    the query is embedded, L2-normalized and searched by inner product (cosine
    similarity on normalized vectors). Each hit carries:
    - its paragraph index (usable with /api/context),
    - its text,
    - its source file name,
    - that file's second line.
    Hit texts are also written to results.txt for /download.
    """
    sid = request.args.get("sid") or request.form.get("sid")
    if not sid:
        sid = str(uuid.uuid4())
    up_folder, _, _, result_file = get_paths(sid)
    paragraphs, _embeddings, index_faiss, _spans, para_file_idx, file_meta = index_data.get(
        sid, ([], None, None, [], [], [])
    )
    uploaded_filenames = sorted(os.listdir(up_folder))
    results = []
    query = ""
    k = 5
    if request.method == "POST":
        query = request.form.get("query", "").strip()
        try:
            k = int(request.form.get("topk", 5))
        except ValueError:
            k = 5
        # k < 1 is not a valid FAISS search size (negative values crash the
        # request), so clamp whatever the form sent.
        k = max(1, k)
        if paragraphs and query:
            # FAISS expects a C-contiguous float32 (n, d) matrix.
            q_embed = np.ascontiguousarray(model.encode([query]), dtype=np.float32)
            if q_embed.ndim == 1:
                q_embed = q_embed[np.newaxis, :]
            faiss.normalize_L2(q_embed)
            _scores, ids = index_faiss.search(q_embed, k=min(k, len(paragraphs)))
            for raw_id in ids[0]:
                i = int(raw_id)
                # FAISS pads missing hits with -1, and paragraphs[-1] would silently
                # return the wrong paragraph, so skip anything out of range.
                if not 0 <= i < len(paragraphs):
                    continue
                file_idx = para_file_idx[i] if i < len(para_file_idx) else -1
                meta = file_meta[file_idx] if 0 <= file_idx < len(file_meta) else None
                results.append({
                    "idx": i,
                    "text": paragraphs[i],
                    "file": meta["name"] if meta is not None else "unknown",
                    "second_line": meta["second_line"] if meta is not None else "",
                })
            with open(result_file, "w", encoding="utf-8") as f:
                f.writelines(r["text"] + "\n\n" for r in results)
    return render_template(
        "index.html",
        results=results,
        query=query,
        topk=k,
        sid=sid,
        uploaded_filenames=uploaded_filenames,
    )
_UPLOAD_EXTENSIONS = (".pdf", ".txt")


def _safe_upload_name(filename):
    """Reduce a client-supplied *filename* to a bare, non-hidden file name, or None if unusable."""
    if not filename:
        return None
    # Clients may send full paths with either separator. Keep only the final
    # component, and drop NUL bytes, which open() rejects.
    name = filename.replace("\\", "/").rsplit("/", 1)[-1].replace("\x00", "")
    # Refuse empty, ".", ".." and hidden (dot-prefixed) names.
    if not name or name.startswith("."):
        return None
    return name


@app.route("/upload", methods=["POST"])
def upload_file():
    """Save uploaded .pdf/.txt files into this sid's upload folder, then rebuild its index.

    Query params: sid (required). Form field: "file" (may repeat).
    Files with other extensions are ignored. Responds 204 on success and 400
    when sid is missing.
    """
    sid = request.args.get("sid")
    if not sid:
        return ("Missing sid", 400)
    up_folder, _, _, _ = get_paths(sid)
    for file in request.files.getlist("file"):
        # file.filename is attacker-controlled: joined unsanitized, "../x.txt"
        # or an absolute path would write outside up_folder.
        name = _safe_upload_name(file.filename) if file else None
        if name and name.lower().endswith(_UPLOAD_EXTENSIONS):
            file.save(os.path.join(up_folder, name))
    rebuild_merged_and_index(sid)
    return ("", 204)
@app.route("/download")
def download():
    """Send this sid's results.txt (written by the most recent search on '/') as an attachment."""
    sid = request.args.get("sid")
    if sid:
        result_path = get_paths(sid)[3]
        if os.path.exists(result_path):
            return send_file(result_path, as_attachment=True)
        return ("Nothing to download", 404)
    return ("Missing sid", 400)
@app.route("/download_merged")
def download_merged():
    """Send this sid's merged.txt (all uploaded files concatenated) as an attachment."""
    sid = request.args.get("sid")
    if not sid:
        return ("Missing sid", 400)
    merged_path = get_paths(sid)[2]
    if os.path.exists(merged_path):
        return send_file(merged_path, as_attachment=True)
    return ("Nothing to download", 404)
def _is_strict_subpath(path: str, base: str) -> bool:
    """Return True if *path* resolves (symlinks followed) to a location strictly inside *base*."""
    real_path = os.path.realpath(path)
    real_base = os.path.realpath(base)
    if real_path == real_base:
        return False
    try:
        return os.path.commonpath([real_path, real_base]) == real_base
    except ValueError:  # e.g. paths on different drives (Windows)
        return False


@app.route("/reset")
def reset():
    """Delete this sid's upload and result folders, drop its cached index, and redirect to '/'.

    Query params: sid (optional; without it this is just a redirect). Because
    the redirect carries no sid, '/' starts a fresh session with a new UUID.
    Responds 400 if the sid does not map to a folder inside the base folders.
    """
    sid = request.args.get("sid")
    if not sid:
        return redirect(url_for('index'))
    up_folder, res_folder, _, _ = get_paths(sid)
    targets = ((up_folder, BASE_UPLOADS), (res_folder, BASE_RESULTS))
    # sid is client-controlled: refuse to delete anything that does not resolve
    # to a strict sub-folder of the expected base (e.g. sid="/" or "..").
    if not all(_is_strict_subpath(folder, base) for folder, base in targets):
        return ("Invalid sid", 400)
    for folder, _base in targets:
        # rmtree also handles sub-folders; removing the per-sid folder itself
        # avoids leaving empty directories behind for abandoned sids.
        if os.path.islink(folder):
            os.unlink(folder)
        elif os.path.isdir(folder):
            shutil.rmtree(folder)
    index_data.pop(sid, None)  # drop cached embeddings
    return redirect(url_for('index'))
@app.route("/api/ping")
def ping():
    """Lightweight endpoint that always replies 'pong' with HTTP 200 (e.g. for health checks)."""
    return ("pong", 200)
@app.route("/api/context")
def api_context():
    """Return the full merged text plus the exact character span of one paragraph.

    Query params: sid, idx (int; index into the sid's paragraph list).
    Response: { merged: str, start: int, end: int, total_len: int }
    Errors (JSON ``{"error": ...}``):
    - 400 for a missing sid, a non-integer idx, or an idx out of range;
    - 404 when the sid has no index or merged.txt is missing.
    """
    sid = request.args.get("sid")
    if not sid:
        return jsonify(error="Missing sid"), 400

    raw_idx = request.args.get("idx", "-1")
    try:
        idx = int(raw_idx)
    except (TypeError, ValueError):
        return jsonify(error="Bad idx"), 400

    entry = index_data.get(sid)
    if entry is None or entry[0] is None or entry[3] is None:
        return jsonify(error="No index for this sid. Upload files first."), 404
    spans = entry[3]
    if idx < 0 or idx >= len(spans):
        return jsonify(error="idx out of range"), 400

    merged_path = get_paths(sid)[2]
    if not os.path.exists(merged_path):
        return jsonify(error="merged.txt not found"), 404
    with open(merged_path, "r", encoding="utf-8") as fh:
        merged_text = fh.read()

    span_start, span_end = spans[idx]
    return jsonify(
        merged=merged_text,
        start=span_start,
        end=span_end,
        total_len=len(merged_text)
    )
if __name__ == "__main__":
    # Run the built-in server on all interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")