# Standard library
import logging
import os
import pickle
import shutil
import tempfile
import traceback
import zipfile
from contextlib import asynccontextmanager
from functools import lru_cache
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

# Third-party
import aiofiles
import faiss
import gcsfs
import polars as pl
import torch
from tqdm import tqdm
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, PrivateAttr
from pydantic_settings import BaseSettings
from sentence_transformers import CrossEncoder
from starlette.concurrency import run_in_threadpool
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
)
from whoosh import index
from whoosh.analysis import StemmingAnalyzer
from whoosh.fields import ID, Schema, TEXT
from whoosh.qparser import MultifieldParser

# LangChain
from langchain.schema import BaseRetriever, Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings

# === Logging ===
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Settings(BaseSettings):
    # Parquet + Whoosh/FAISS
    parquet_path: str = "gs://mda_kul_project/data/consolidated_clean_pred.parquet"
    whoosh_dir: str = "gs://mda_kul_project/whoosh_index"
    vectorstore_path: str = "gs://mda_kul_project/vectorstore_index"

    # Models
    embedding_model: str = "sentence-transformers/LaBSE"
    llm_model: str = "google/flan-t5-base"
    cross_encoder_model: str = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"

    # RAG parameters
    chunk_size: int = 750
    chunk_overlap: int = 100
    hybrid_k: int = 2
    assistant_role: str = (
        "You are a knowledgeable project analyst. You have access to the following retrieved document snippets."
    )
    skip_warmup: bool = True
    allowed_origins: List[str] = ["*"]

    class Config:
        env_file = ".env"


settings = Settings()
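# Illustrative only: any Settings field above can be overridden via a .env file or
# environment variables (pydantic-settings matches field names case-insensitively), e.g.
#   PARQUET_PATH=gs://my-bucket/data/projects.parquet
#   CHUNK_SIZE=500
#   SKIP_WARMUP=false
# The bucket name and values shown here are placeholders, not the ones this app uses.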
# === Global Embeddings & Cache ===
EMBEDDING = HuggingFaceEmbeddings(model_name=settings.embedding_model)


@lru_cache(maxsize=1024)  # in-process memoization; the 1024-query bound is an arbitrary choice
def embed_query_cached(query: str) -> List[float]:
    """Cache embedding vectors for queries."""
    return EMBEDDING.embed_query(query.strip().lower())
# === Whoosh Cache & Builder ===
async def build_whoosh_index(docs: List[Document], whoosh_dir: str) -> index.Index:
    """
    If gs://.../whoosh_index.zip exists, download & extract it once.
    Otherwise build locally from docs and upload the ZIP back to GCS.
    """
    fs = gcsfs.GCSFileSystem()
    is_gcs = whoosh_dir.startswith("gs://")
    zip_uri = whoosh_dir.rstrip("/") + ".zip"
    local_zip = "/tmp/whoosh_index.zip"
    local_dir = "/tmp/whoosh_index"

    # Clean slate
    if os.path.exists(local_dir):
        shutil.rmtree(local_dir)
    os.makedirs(local_dir, exist_ok=True)

    # 1️⃣ Try downloading the ZIP if it exists on GCS
    if is_gcs and await run_in_threadpool(fs.exists, zip_uri):
        logger.info("Found whoosh_index.zip on GCS; downloading…")
        await run_in_threadpool(fs.get, zip_uri, local_zip)
        # Extract all files (flat) into local_dir
        with zipfile.ZipFile(local_zip, "r") as zf:
            for member in zf.infolist():
                if member.is_dir():
                    continue
                filename = os.path.basename(member.filename)
                if not filename:
                    continue
                target = os.path.join(local_dir, filename)
                os.makedirs(os.path.dirname(target), exist_ok=True)
                with zf.open(member) as src, open(target, "wb") as dst:
                    dst.write(src.read())
        logger.info("Whoosh index extracted from ZIP.")
    else:
        logger.info("No whoosh_index.zip found; building index from docs.")
        # Define the schema with stored content
        schema = Schema(
            id=ID(stored=True, unique=True),
            content=TEXT(stored=True, analyzer=StemmingAnalyzer()),
        )
        # Create the index
        ix = index.create_in(local_dir, schema)
        writer = ix.writer()
        for doc in docs:
            writer.add_document(
                id=doc.metadata.get("id", ""),
                content=doc.page_content,
            )
        writer.commit()
        logger.info("Whoosh index built locally.")
        # Upload the ZIP back to GCS
        if is_gcs:
            logger.info("Zipping and uploading new whoosh_index.zip to GCS…")
            with zipfile.ZipFile(local_zip, "w", zipfile.ZIP_DEFLATED) as zf:
                for root, _, files in os.walk(local_dir):
                    for fname in files:
                        full = os.path.join(root, fname)
                        arc = os.path.relpath(full, local_dir)
                        zf.write(full, arc)
            await run_in_threadpool(fs.put, local_zip, zip_uri)
            logger.info("Uploaded whoosh_index.zip to GCS.")

    # 2️⃣ Finally open the index and return it
    ix = index.open_dir(local_dir)
    return ix
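# Illustrative usage (a sketch, assuming `docs` comes from load_documents below and
# settings.whoosh_dir points at a readable/writable GCS prefix):
#   ix = await build_whoosh_index(docs, settings.whoosh_dir)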
# === Document Loader ===
async def load_documents(
    path: str,
    sample_size: Optional[int] = None
) -> List[Document]:
    """
    Load project data from a Parquet file (local path or GCS URI),
    assemble metadata context for each row, and return as Document objects.
    """
    def _read_local(p: str, n: Optional[int]):
        # streaming scan keeps memory low
        lf = pl.scan_parquet(p)
        if n:
            lf = lf.limit(n)
        return lf.collect(streaming=True)

    def _read_gcs(p: str, n: Optional[int]):
        # download to a temp file synchronously, then read with Polars
        fs = gcsfs.GCSFileSystem()
        with tempfile.TemporaryDirectory() as td:
            local_path = os.path.join(td, "data.parquet")
            fs.get(p, local_path, recursive=False)
            df = pl.read_parquet(local_path)
        if n:
            df = df.head(n)
        return df

    try:
        if path.startswith("gs://"):
            df = await run_in_threadpool(_read_gcs, path, sample_size)
        else:
            df = await run_in_threadpool(_read_local, path, sample_size)
    except Exception as e:
        logger.error(f"Error loading documents: {e}")
        raise HTTPException(status_code=500, detail="Document loading failed.")
    docs: List[Document] = []
    for row in df.rows(named=True):
        context_parts: List[str] = []
        # build metadata context
        max_contrib = row.get("ecMaxContribution", "")
        end_date = row.get("endDate", "")
        duration = row.get("durationDays", "")
        status = row.get("status", "")
        legal = row.get("legalBasis", "")
        framework = row.get("frameworkProgramme", "")
        scheme = row.get("fundingScheme", "")
        names = row.get("list_name", []) or []
        cities = row.get("list_city", []) or []
        countries = row.get("list_country", []) or []
        activity = row.get("list_activityType", []) or []
        contributions = row.get("list_ecContribution", []) or []
        smes = row.get("list_sme", []) or []
        project_id = row.get("id", "")
        pred = row.get("predicted_label", "")
        proba = row.get("predicted_prob", "")
        top1_feats = row.get("top1_features", "")
        top2_feats = row.get("top2_features", "")
        top3_feats = row.get("top3_features", "")
        top1_shap = row.get("top1_shap", "")
        top2_shap = row.get("top2_shap", "")
        top3_shap = row.get("top3_shap", "")
        context_parts.append(
            f"This project falls under framework {framework} with funding scheme {scheme}; "
            f"its status is {status} and its legal basis is {legal}."
        )
        context_parts.append(
            f"It ends on {end_date} after {duration} days and has a max EC contribution of {max_contrib}."
        )
        context_parts.append("Participating organizations:")
        for i, name in enumerate(names):
            city = cities[i] if i < len(cities) else ""
            country = countries[i] if i < len(countries) else ""
            act = activity[i] if i < len(activity) else ""
            contrib = contributions[i] if i < len(contributions) else ""
            sme_flag = "SME" if (smes and i < len(smes) and smes[i]) else "non-SME"
            context_parts.append(
                f"- {name} in {city}, {country}, activity: {act}, contributed: {contrib}, {sme_flag}."
            )

        if status in (None, "signed", "SIGNED", "Signed"):
            if int(pred) == 1:
                label = "TERMINATED"
                score = float(proba)
            else:
                label = "CLOSED"
                score = 1 - float(proba)
            score_str = f"{score:.2f}"
            context_parts.append(
                f"- Project {project_id} is predicted to be {label} (score={score_str}). "
                f"The 3 most predictive features were: "
                f"{top1_feats} ({top1_shap:.3f}), "
                f"{top2_feats} ({top2_shap:.3f}), "
                f"{top3_feats} ({top3_shap:.3f})."
            )
        title_report = row.get("list_title_report", "")
        objective = row.get("objective", "")
        full_body = f"{title_report} {objective}"
        full_text = " ".join(context_parts + [full_body])
        meta: Dict[str, Any] = {
            "id": str(row.get("id", "")),
            "startDate": str(row.get("startDate", "")),
            "endDate": str(row.get("endDate", "")),
            "status": str(row.get("status", "")),
            "legalBasis": str(row.get("legalBasis", "")),
        }
        docs.append(Document(page_content=full_text, metadata=meta))

    return docs
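# Illustrative shape of the output (a sketch): each Parquet row becomes one Document
# whose page_content concatenates the generated context sentences with the report
# title and objective, and whose metadata mirrors a few key columns, e.g.
#   docs = await load_documents(settings.parquet_path, sample_size=100)
#   docs[0].metadata
#   # {'id': '...', 'startDate': '...', 'endDate': '...', 'status': '...', 'legalBasis': '...'}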
# === BM25 Search ===
async def bm25_search(ix: index.Index, query: str, k: int) -> List[Document]:
    parser = MultifieldParser(["content"], schema=ix.schema)

    def _search() -> List[Document]:
        with ix.searcher() as searcher:
            hits = searcher.search(parser.parse(query), limit=k)
            return [Document(page_content=h["content"], metadata={"id": h["id"]}) for h in hits]

    return await run_in_threadpool(_search)
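# Illustrative usage (a sketch; the query string is only an example):
#   hits = await bm25_search(ix, "renewable energy storage", k=settings.hybrid_k)
#   # -> up to k Documents ranked by BM25, each carrying {"id": ...} metadata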
# === Helper: build or load FAISS with mmap ===
async def build_or_load_faiss(
    docs: List[Document],
    vectorstore_path: str,
    batch_size: int = 15000
) -> FAISS:
    """
    Expects a ZIP at vectorstore_path + ".zip" containing:
      - index.faiss
      - index.pkl
    Files may be nested under a subfolder (e.g. vectorstore_index_colab/).
    If the ZIP exists on GCS, download & load only.
    Otherwise, build from `docs`, save, re-zip, and upload.
    """
    fs = gcsfs.GCSFileSystem()
    is_gcs = vectorstore_path.startswith("gs://")
    zip_uri = vectorstore_path.rstrip("/") + ".zip"
    local_zip = "/tmp/faiss_index.zip"
    local_dir = "/tmp/faiss_store"

    # 1) if ZIP exists, download & extract
    if is_gcs and await run_in_threadpool(fs.exists, zip_uri):
        logger.info("Found FAISS ZIP on GCS; loading only.")
        # clean slate
        if os.path.exists(local_dir):
            shutil.rmtree(local_dir)
        os.makedirs(local_dir, exist_ok=True)
        # download zip
        await run_in_threadpool(fs.get, zip_uri, local_zip)

        # extract
        def _extract():
            with zipfile.ZipFile(local_zip, "r") as zf:
                zf.extractall(local_dir)

        await run_in_threadpool(_extract)

        # locate the two files anywhere under local_dir
        idx_path = None
        meta_path = None
        for root, _, files in os.walk(local_dir):
            if "index.faiss" in files:
                idx_path = os.path.join(root, "index.faiss")
            if "index.pkl" in files:
                meta_path = os.path.join(root, "index.pkl")
        if not idx_path or not meta_path:
            raise FileNotFoundError("Couldn't find index.faiss or index.pkl in extracted ZIP.")

        # memory-map load
        mmap_index = await run_in_threadpool(
            faiss.read_index, idx_path, faiss.IO_FLAG_MMAP
        )
        # load metadata
        with open(meta_path, "rb") as f:
            saved = pickle.load(f)
        # unpack metadata
        if isinstance(saved, tuple):
            _, docstore, index_to_docstore = (
                saved if len(saved) == 3 else (None, *saved)
            )
        else:
            # Fall back lazily so private attributes are only touched when the
            # public ones are absent.
            docstore = getattr(saved, "docstore", None) or saved._docstore
            index_to_docstore = (
                getattr(saved, "index_to_docstore", None)
                or getattr(saved, "_index_to_docstore", None)
                or saved._faiss_index_to_docstore
            )
        # reconstruct FAISS
        vs = FAISS(
            embedding_function=EMBEDDING,
            index=mmap_index,
            docstore=docstore,
            index_to_docstore_id=index_to_docstore,
        )
        logger.info("FAISS index loaded from ZIP.")
        return vs
    # 2) otherwise, build from scratch and upload
    logger.info("No FAISS ZIP found; building index from scratch.")
    if os.path.exists(local_dir):
        shutil.rmtree(local_dir)
    os.makedirs(local_dir, exist_ok=True)
    vs: Optional[FAISS] = None
    for i in range(0, len(docs), batch_size):
        batch = docs[i : i + batch_size]
        if vs is None:
            vs = FAISS.from_documents(batch, EMBEDDING)
        else:
            vs.add_documents(batch)
    assert vs is not None, "No documents to index!"

    # save locally
    vs.save_local(local_dir)

    if is_gcs:
        # re-zip all contents of local_dir (flattened)
        def _zip_dir():
            with zipfile.ZipFile(local_zip, "w", zipfile.ZIP_DEFLATED) as zf:
                for root, _, files in os.walk(local_dir):
                    for fname in files:
                        full = os.path.join(root, fname)
                        arc = os.path.relpath(full, local_dir)
                        zf.write(full, arc)

        await run_in_threadpool(_zip_dir)
        await run_in_threadpool(fs.put, local_zip, zip_uri)

    logger.info("Built FAISS index and uploaded ZIP to GCS.")
    return vs
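# Expected GCS layout for the fast-load path (illustrative; as the docstring notes,
# the subfolder name inside the ZIP may differ):
#   gs://.../vectorstore_index.zip
#     vectorstore_index_colab/index.faiss
#     vectorstore_index_colab/index.pkl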
# === Index Builder ===
async def build_indexes(
    parquet_path: str,
    vectorstore_path: str,
    whoosh_dir: str,
    chunk_size: int,
    chunk_overlap: int,
    debug_size: Optional[int]
) -> Tuple[FAISS, index.Index]:
    """
    Load documents, build/load Whoosh and FAISS indices, and return both.
    """
    docs = await load_documents(parquet_path, debug_size)
    ix = await build_whoosh_index(docs, whoosh_dir)
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = splitter.split_documents(docs)
    # build or load (with mmap) FAISS
    vs = await build_or_load_faiss(chunks, vectorstore_path)
    return vs, ix
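# Illustrative startup wiring (a sketch): inside an async context such as a FastAPI
# lifespan handler, both indexes could be prepared with:
#   vs, ix = await build_indexes(
#       settings.parquet_path,
#       settings.vectorstore_path,
#       settings.whoosh_dir,
#       settings.chunk_size,
#       settings.chunk_overlap,
#       debug_size=None,
#   )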
# === Hybrid Retriever ===
class HybridRetriever(BaseRetriever):
    """Hybrid retriever combining BM25 and FAISS with cross-encoder re-ranking."""

    # store FAISS and Whoosh under private attributes to avoid Pydantic field errors
    _vs: FAISS = PrivateAttr()
    _ix: index.Index = PrivateAttr()
    _compressor: DocumentCompressorPipeline = PrivateAttr()
    _cross_encoder: CrossEncoder = PrivateAttr()

    def __init__(
        self,
        vs: FAISS,
        ix: index.Index,
        compressor: DocumentCompressorPipeline,
        cross_encoder: CrossEncoder
    ) -> None:
        super().__init__()
        object.__setattr__(self, '_vs', vs)
        object.__setattr__(self, '_ix', ix)
        object.__setattr__(self, '_compressor', compressor)
        object.__setattr__(self, '_cross_encoder', cross_encoder)

    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        # BM25 retrieval using Whoosh index
        bm_docs = await bm25_search(self._ix, query, settings.hybrid_k)
        # Dense retrieval using FAISS
        dense_docs = self._vs.similarity_search_by_vector(
            embed_query_cached(query), k=settings.hybrid_k
        )
        # Cross-encoder re-ranking
        candidates = bm_docs + dense_docs
        scores = self._cross_encoder.predict([
            (query, doc.page_content) for doc in candidates
        ])
        ranked = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
        top = [doc for _, doc in ranked[: settings.hybrid_k]]
        # Compress and return
        return self._compressor.compress_documents(top, query=query)
    def _get_relevant_documents(self, query: str) -> List[Document]:
        import asyncio

        # Synchronous entry point: run the async retrieval on a fresh event loop
        # (callers are not expected to invoke this from inside a running loop).
        return asyncio.run(self._aget_relevant_documents(query))
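# Illustrative wiring of the retriever (a sketch; the actual app startup may differ):
#   cross_encoder = CrossEncoder(settings.cross_encoder_model)
#   compressor = DocumentCompressorPipeline(
#       transformers=[EmbeddingsRedundantFilter(embeddings=EMBEDDING)]
#   )
#   retriever = HybridRetriever(vs, ix, compressor, cross_encoder)
#   top_docs = await retriever._aget_relevant_documents("EU-funded hydrogen projects")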