File size: 4,088 Bytes
5785ed4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import os, glob, json, faiss, numpy as np
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from groq import Groq
from src.config import *
# Ensure directories exist
os.makedirs(INDEX_DIR, exist_ok=True)
# Initialize models
# Sentence embedder used for both indexing and query-time retrieval.
embedder = SentenceTransformer(EMBEDDING_MODEL)
# Local (offline) summarization pipeline for the "Quick Summary" mode.
summarizer = pipeline("summarization", model=SUMMARIZER_MODEL)
# Groq chat client; requires the GROQ_API_KEY environment variable to be set
# (Groq(api_key=None) will fail only when a request is actually made).
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
# --- Token Counter ---
# Prefer an exact tokenizer when tiktoken is installed; otherwise fall back
# to a cheap characters/4 heuristic so the rest of the pipeline still works.
try:
    import tiktoken

    enc = tiktoken.encoding_for_model("gpt-3.5-turbo")

    def count_tokens(text):
        """Return the exact token count of `text` per the GPT-3.5 encoding."""
        return len(enc.encode(text))
except Exception:
    def count_tokens(text):
        """Approximate the token count as len(text) // 4 (no tiktoken)."""
        return len(text) // 4
# --- Build Index ---
def build_index():
    """Chunk every raw .txt document, embed the chunks, and persist a FAISS index.

    Walks RAW_DIR/<domain>/*.txt, splits each file into overlapping character
    chunks, embeds them with the sentence-transformer, and writes both the
    FAISS index (INDEX_PATH) and the chunk metadata JSON (META_PATH).

    Returns:
        tuple: (faiss.IndexFlatIP, list[dict]) -- the in-memory index and the
        per-chunk metadata dicts with keys "domain", "text", "source".
    """
    # 384 is the output dimension of the configured embedding model; inner
    # product over normalized embeddings == cosine similarity.
    index = faiss.IndexFlatIP(384)
    meta = []

    def chunk_text(text, size=800, overlap=120):
        """Split `text` into `size`-char chunks overlapping by `overlap` chars."""
        chunks = []
        i = 0
        while i < len(text):
            chunks.append(text[i:i + size].strip())
            i += size - overlap
        return chunks

    for domain_dir in glob.glob(f"{RAW_DIR}/*"):
        domain = os.path.basename(domain_dir)
        for path in glob.glob(f"{domain_dir}/*.txt"):
            with open(path, encoding="utf-8") as f:
                text = f.read()
            chunks = chunk_text(text)
            if not chunks:
                # Empty file: encoding/adding a zero-length batch would error.
                continue
            vecs = embedder.encode(chunks, normalize_embeddings=True)
            index.add(np.array(vecs).astype("float32"))
            for ch in chunks:
                meta.append({"domain": domain, "text": ch, "source": os.path.basename(path)})
            print(f"✅ Indexed {domain}/{os.path.basename(path)} ({len(chunks)} chunks)")
    faiss.write_index(index, INDEX_PATH)
    # Use a context manager so the metadata file is flushed and closed
    # (the original json.dump(..., open(...)) leaked the handle).
    with open(META_PATH, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False)
    print(f"🎉 Index built: {len(meta)} chunks total.")
    return index, meta
# Load or build index
# First run (no persisted index yet) builds from the raw corpus; otherwise
# the FAISS index and its chunk metadata are loaded from disk.
if not os.path.exists(INDEX_PATH):
    index, meta = build_index()
else:
    index = faiss.read_index(INDEX_PATH)
    # Context manager closes the metadata file (the original leaked the handle).
    with open(META_PATH, encoding="utf-8") as f:
        meta = json.load(f)
# --- Retrieval ---
def retrieve_text(query, topk=TOP_K_RESULTS):
    """Return metadata dicts for the `topk` chunks most similar to `query`."""
    query_vec = embedder.encode([query], normalize_embeddings=True).astype("float32")
    _scores, hit_ids = index.search(query_vec, topk)
    hits = []
    for chunk_id in hit_ids[0]:
        hits.append(meta[chunk_id])
    return hits
# --- Token limiter ---
def trim_to_token_limit(text, max_tokens=MAX_TOKENS):
    """Trim `text` so its estimated token count fits within `max_tokens`.

    Trimming is proportional by character count, so the result can still be
    slightly over the limit when token density is uneven across the text.

    Args:
        text: Context string to bound.
        max_tokens: Token budget (defaults to the configured MAX_TOKENS).

    Returns:
        str: `text` unchanged if it fits, otherwise a character-truncated prefix.
    """
    tokens = count_tokens(text)
    if tokens > max_tokens:
        print(f"⚠️ Context too long ({tokens}). Trimming...")
        # Scale the character length by the token overshoot ratio.
        cutoff_ratio = max_tokens / tokens
        text = text[:int(len(text) * cutoff_ratio)]
    return text
# --- Main Answer Generator ---
def generate_answer(query, mode):
    """Answer `query` from retrieved context via the local summarizer or Groq.

    Args:
        query: User question; used both for retrieval and for generation.
        mode: "Quick Summary (Offline)" runs the local summarization pipeline;
            any other value calls the Groq chat API (primary model first,
            then the fallback model).

    Returns:
        str: Markdown containing the synthesized answer followed by a
        "Source Highlights" section (first 300 chars of each retrieved chunk).
    """
    retrieved = retrieve_text(query)
    combined = " ".join([r["text"] for r in retrieved])
    safe_context = trim_to_token_limit(combined)
    if mode == "Quick Summary (Offline)":
        summary = summarizer(safe_context, max_length=180, min_length=60, do_sample=False)[0]["summary_text"]
    else:
        prompt = f"""
You are MindMesh, a cross-domain reasoning assistant.
Question: {query}
Context: {safe_context}
Synthesize a precise and insightful answer across disciplines.
"""
        try:
            response = client.chat.completions.create(
                model=PRIMARY_GROQ_MODEL,
                messages=[{"role": "user", "content": prompt}],
            )
            summary = response.choices[0].message.content.strip()
        except Exception:
            # Primary model failed (rate limit, outage, bad model id):
            # retry once with the fallback model before surfacing the error.
            try:
                response = client.chat.completions.create(
                    model=FALLBACK_GROQ_MODEL,
                    messages=[{"role": "user", "content": prompt}],
                )
                summary = response.choices[0].message.content.strip()
            except Exception as e2:
                summary = f"⚠️ Groq API error: {str(e2)}"
    md = f"## 🧠 Synthesized Insight\n{summary}\n\n---\n### 📚 Source Highlights\n"
    for r in retrieved:
        md += f"**{r['domain'].title()} — {r['source']}** \n{r['text'][:300]}...\n\n"
    return md
# --- Rebuild Index with Feedback ---
def rebuild():
    """Rebuild the FAISS index, yielding status strings for the UI.

    A generator so the caller (e.g. a Gradio handler) can stream progress:
    one message before the rebuild starts and one after it completes.
    """
    yield "⚙️ Rebuilding FAISS index... please wait ⏳"
    build_index()
    # Original source had this literal broken across lines (invalid syntax
    # from a mangled emoji); reconstructed as a single-line message.
    yield "✅ Index rebuilt successfully! (FAISS + metadata updated)"