Spaces:
Running
Running
| import os | |
| import pickle | |
| import logging | |
| import platform | |
| import gradio as gr | |
| from llama_cpp import Llama | |
| from huggingface_hub import hf_hub_download | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| # Qdrant filter models | |
| from qdrant_client.http.models import Filter, FieldCondition, MatchValue | |
# ====================== LOGGING ======================
# One root configuration for the whole app: INFO and above, printed as
# "LEVEL | message".
logging.basicConfig(
    format="%(levelname)s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ====================== CONFIG ======================
# Hugging Face Hub repository that hosts the fine-tuned GGUF model.
repo_id = "robertolofaro/articles-model"

# Radio-button label shown in the UI -> backend key passed to get_vectorstore().
BACKENDS = {
    "FAISS - RAG (HNSW)": "FAISS",
    "Qdrant - RAG": "Qdrant",
}

# Data artifacts are resolved relative to this script, not the working directory.
_HERE = os.path.dirname(os.path.abspath(__file__))
METADATA_PATH, FAISS_PATH, QDRANT_PATH = (
    os.path.join(_HERE, leaf) for leaf in ("metadata.pkl", "faiss_hnsw", "qdrant_db")
)
QDRANT_COLLECTION = "articles"
# ====================== GPU / HARDWARE DETECTION ======================
# Override everything with the N_GPU_LAYERS env var when you need fine control.
# Otherwise: CUDA → all layers on GPU (-1); Apple Silicon → Metal (-1); else CPU (0).
def _detect_gpu_layers() -> int:
    """Return the ``n_gpu_layers`` value to pass to llama.cpp.

    Resolution order:
      1. The N_GPU_LAYERS env var, when it parses as an int. Unset or blank is
         ignored. An invalid value is logged and ignored rather than crashing
         the app at import time.
      2. -1 (all layers) when torch is importable and reports CUDA.
      3. -1 on macOS arm64 (Apple Silicon / Metal).
      4. 0 (CPU only) otherwise.
    """
    override = os.environ.get("N_GPU_LAYERS", "").strip()
    if override:
        try:
            val = int(override)
        except ValueError:
            logger.warning(
                "Ignoring invalid N_GPU_LAYERS=%r — falling back to auto-detection",
                override,
            )
        else:
            logger.info("N_GPU_LAYERS override: %d", val)
            return val
    try:
        # torch is optional here: a missing install just skips the CUDA probe.
        import torch
        if torch.cuda.is_available():
            logger.info("CUDA detected — offloading all layers to GPU")
            return -1
    except ImportError:
        pass
    if platform.system() == "Darwin" and platform.machine() == "arm64":
        logger.info("Apple Silicon / Metal detected — offloading all layers to GPU")
        return -1
    logger.info("No GPU detected — running on CPU only")
    return 0


# Resolved once at import time; read by the LLM loader.
N_GPU_LAYERS = _detect_gpu_layers()
# ====================== LOAD METADATA ======================
def _load_metadata():
    """Unpickle metadata.pkl (stored next to this script) and return it.

    The object is expected to be a DataFrame; its row count and column names
    are logged. On a missing file, or any other failure while loading or
    inspecting it, an error is logged and None is returned.
    """
    try:
        # pickle executes code on load: METADATA_PATH must point at a trusted,
        # locally shipped file, never at user-supplied data.
        with open(METADATA_PATH, "rb") as fh:
            frame = pickle.load(fh)
        row_count = len(frame)
        column_names = frame.columns.tolist()
        logger.info("metadata.pkl loaded — %d rows, columns: %s", row_count, column_names)
    except FileNotFoundError:
        logger.error("metadata.pkl not found at %s", METADATA_PATH)
        return None
    except Exception as err:
        logger.error("Failed to load metadata.pkl: %s", err)
        return None
    return frame


# Article metadata shared by the dropdown helpers; None when loading failed.
_METADATA_DF = _load_metadata()
def load_category_list():
    """Build the choices for the category dropdown.

    Returns:
        ["All categories"] followed by the sorted, de-duplicated, non-null
        values of the metadata's "article_category" column. Only the
        placeholder is returned when metadata is unavailable or lacks that
        column.
    """
    choices = ["All categories"]
    has_column = _METADATA_DF is not None and "article_category" in _METADATA_DF.columns
    if not has_column:
        logger.warning("article_category column not found — showing only 'All categories'")
        return choices
    categories = sorted(_METADATA_DF["article_category"].dropna().unique().tolist())
    logger.info("Found %d categories", len(categories))
    choices.extend(categories)
    return choices
def load_articles_for_category(category: str):
    """Return the choices for the article dropdown, scoped to *category*.

    Args:
        category: Value from the category dropdown. "All categories", None or
            "" mean no restriction.

    Returns:
        ["All articles in category"] followed by the sorted, de-duplicated,
        non-null "article_title" values of the matching metadata rows (all
        rows when unrestricted). Only the placeholder is returned in three
        cases: metadata is missing; it has no "article_title" column; or a
        specific category is requested but there is no "article_category"
        column to filter on.
    """
    default = ["All articles in category"]
    if _METADATA_DF is None or "article_title" not in _METADATA_DF.columns:
        return default
    if category in ("All categories", None, ""):
        rows = _METADATA_DF
    elif "article_category" in _METADATA_DF.columns:
        rows = _METADATA_DF.loc[_METADATA_DF["article_category"] == category]
    else:
        # Indexing the missing column would raise KeyError inside the Gradio
        # event handler; the category filter cannot be honoured, so offer
        # only the placeholder.
        logger.warning("article_category column not found — cannot filter by %r", category)
        return default
    titles = sorted(rows["article_title"].dropna().unique().tolist())
    return default + titles


# Category dropdown choices, computed once at startup.
CATEGORY_LIST = load_category_list()
# ====================== LOAD LLM ======================
# LOCAL_MODEL_PATH env var lets you point to a local GGUF and skip the HF download.
# N_THREADS env var overrides thread count (default: 4 on CPU, 2 on GPU).
def _load_llm() -> Llama:
    """Resolve the GGUF model file and build the llama.cpp model.

    Model selection: LOCAL_MODEL_PATH is used when it names an existing file.
    Otherwise ``articles-Q4_K_M.gguf`` is downloaded from ``repo_id``, passing
    HF_TOKEN as the token (None when unset).

    Thread count: N_THREADS overrides it. When N_THREADS is unset or blank,
    the default is used: 2 when layers are offloaded to a GPU, 4 on CPU. A
    non-integer value is logged and replaced by that default instead of
    aborting start-up.
    """
    local_model = os.environ.get("LOCAL_MODEL_PATH")
    if local_model and os.path.isfile(local_model):
        model_path = local_model
        logger.info("Using local model at %s", model_path)
    else:
        if local_model:
            logger.warning("LOCAL_MODEL_PATH set but file not found (%s) — downloading from HF", local_model)
        logger.info("Downloading model from HF hub (%s)…", repo_id)
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename="articles-Q4_K_M.gguf",
            repo_type="model",
            token=os.environ.get("HF_TOKEN"),
        )

    default_threads = 2 if N_GPU_LAYERS != 0 else 4
    raw_threads = os.environ.get("N_THREADS", "").strip()
    n_threads = default_threads
    if raw_threads:
        try:
            n_threads = int(raw_threads)
        except ValueError:
            logger.warning(
                "Ignoring invalid N_THREADS=%r — using %d", raw_threads, default_threads
            )
    logger.info("Llama init: n_gpu_layers=%d, n_threads=%d", N_GPU_LAYERS, n_threads)
    return Llama(
        model_path=model_path,
        n_ctx=8192,
        n_threads=n_threads,
        n_batch=512,
        n_ubatch=512,
        n_gpu_layers=N_GPU_LAYERS,
        verbose=False,
    )


# Loaded once at import time and reused by every chat request.
llm = _load_llm()
# ====================== VECTOR STORES ======================
# Successfully loaded vector stores, keyed by backend name ("FAISS"/"Qdrant").
# Failed loads are not cached, so they are retried on the next request.
vectorstores: dict = {}

# Embedding model shared by every backend; created lazily on first use.
_embeddings = None


def _get_embeddings():
    """Return the shared BGE-small embedding model, creating it on first call.

    Both backends embed queries with the same model and settings, so it is
    loaded once instead of once per backend.
    """
    global _embeddings
    if _embeddings is None:
        _embeddings = HuggingFaceEmbeddings(
            model_name="BAAI/bge-small-en-v1.5",
            encode_kwargs={"normalize_embeddings": True},
        )
    return _embeddings


def _load_faiss(embeddings):
    """Load the FAISS (HNSW) index saved under FAISS_PATH."""
    from langchain_community.vectorstores import FAISS

    # load_local unpickles the stored docstore, hence the explicit opt-in;
    # FAISS_PATH must only ever point at an index you built yourself.
    vs = FAISS.load_local(
        FAISS_PATH,
        embeddings,
        allow_dangerous_deserialization=True,
    )
    logger.info("FAISS index loaded from %s", FAISS_PATH)
    return vs


def _load_qdrant(embeddings):
    """Open collection QDRANT_COLLECTION from the on-disk Qdrant store at QDRANT_PATH."""
    from langchain_qdrant import QdrantVectorStore
    from qdrant_client import QdrantClient

    client = QdrantClient(
        path=QDRANT_PATH,  # path to the qdrant_db folder
        timeout=60,
    )
    try:
        vs = QdrantVectorStore(
            client=client,
            collection_name=QDRANT_COLLECTION,
            embedding=embeddings,
        )
    except Exception:
        # Local-mode Qdrant locks its storage folder. Release it explicitly
        # (rather than relying on garbage collection) so a retry on the next
        # request can reopen the folder.
        client.close()
        raise
    logger.info("Qdrant collection '%s' loaded from %s", QDRANT_COLLECTION, QDRANT_PATH)
    return vs


def get_vectorstore(backend_name: str):
    """Return the cached vector store for *backend_name*, loading it on first use.

    Args:
        backend_name: "FAISS" or "Qdrant" (the values of BACKENDS). Any other
            value is logged and falls back to the FAISS index.

    Returns:
        The LangChain vector store, or None if loading raised. The failure is
        logged with its traceback and is not cached.
    """
    cached = vectorstores.get(backend_name)
    if cached is not None:
        return cached
    try:
        embeddings = _get_embeddings()
        if backend_name == "Qdrant":
            vs = _load_qdrant(embeddings)
        else:
            if backend_name != "FAISS":
                logger.warning("Unknown backend %r — falling back to FAISS", backend_name)
            vs = _load_faiss(embeddings)
    except Exception:
        logger.exception("Failed to load vector store '%s'", backend_name)
        return None
    vectorstores[backend_name] = vs
    logger.info("Vector store '%s' loaded successfully", backend_name)
    return vs
def _faiss_filtered_search(vs, query, k, want_title, want_category):
    """Post-filter an over-fetched FAISS candidate pool on title/category metadata."""
    # Over-fetch so filtering can still leave up to k hits. max() keeps the
    # pool at least k even for large k (min(k * 10, 80) alone can fall below k).
    pool_size = max(k, min(k * 10, 80))
    pool = vs.similarity_search(query, k=pool_size)
    filtered = [
        doc
        for doc in pool
        if (not want_title or doc.metadata.get("article_title") == want_title)
        and (not want_category or doc.metadata.get("article_category") == want_category)
    ][:k]
    if not filtered and (want_title or want_category):
        logger.warning(
            "FAISS post-filter (title=%r, cat=%r) matched 0 docs — returning unfiltered top-%d",
            want_title, want_category, k,
        )
        return pool[:k]
    logger.info(
        "FAISS post-filter (title=%r, cat=%r) → %d/%d docs kept",
        want_title, want_category, len(filtered), len(pool),
    )
    return filtered


def _qdrant_filtered_search(vs, query, k, want_title, want_category):
    """Run a Qdrant search with the title/category filters pushed down to Qdrant."""
    # The "metadata." prefix addresses fields inside the document-metadata
    # payload written by langchain_qdrant. Title and category are ANDed,
    # matching the FAISS post-filter semantics.
    requested = (
        ("title", "metadata.article_title", want_title),
        ("category", "metadata.article_category", want_category),
    )
    conditions = [
        FieldCondition(key=key, match=MatchValue(value=wanted))
        for _, key, wanted in requested
        if wanted
    ]
    search_filter = Filter(must=conditions) if conditions else None
    filter_label = "+".join(label for label, _, wanted in requested if wanted) or "none"
    try:
        docs = vs.similarity_search(query, k=k, filter=search_filter)
    except Exception as exc:
        logger.error("Qdrant search failed with filter: %s", exc)
        logger.warning("Falling back to unfiltered Qdrant search")
        return vs.similarity_search(query, k=k)
    logger.info("Qdrant search (filter=%s) → %d docs", filter_label, len(docs))
    return docs


def _rag_search(vs, query: str, k: int, article_filter: str, category_filter: str):
    """Similarity search with optional metadata filtering.

    Args:
        vs: Vector store returned by get_vectorstore(). If its class name
            contains "FAISS", results are post-filtered in Python. Any other
            store is treated as Qdrant and filtered inside the search.
        query: Search text (the user's question).
        k: Maximum number of documents to return.
        article_filter: Exact ``article_title`` to match; None, "" or
            "All articles in category" disable the title filter.
        category_filter: Exact ``article_category`` to match; None, "" or
            "All categories" disable the category filter.

    Returns:
        The matching documents (at most *k*). Falls back to unfiltered results
        in two cases: FAISS post-filtering keeps nothing, or the filtered Qdrant
        search raises.
    """
    want_title = None if article_filter in (None, "", "All articles in category") else article_filter
    want_category = None if category_filter in (None, "", "All categories") else category_filter
    if "FAISS" in type(vs).__name__:
        return _faiss_filtered_search(vs, query, k, want_title, want_category)
    return _qdrant_filtered_search(vs, query, k, want_title, want_category)
# ====================== SYSTEM PROMPT ======================
# Content of the system turn that opens every prompt built by generate_response.
# The trailing backslashes inside the triple-quoted string are line
# continuations, so those source line breaks do not appear in the prompt text.
# NOTE(review): the prompt demands answers based ONLY on the "Context:" section.
# However, generate_response also sends questions with no Context block (RAG
# disabled, store failed to load, or zero hits). Consider wording for that case.
SYSTEM_PROMPT = """You are the reference expert for the articles contained in the training \
of this model, all extracted from the website robertolofaro.com, and all focused on change.
IMPORTANT: Relevant article excerpts retrieved via semantic search will be injected \
directly in the user message under the heading "Context:". You MUST use those excerpts \
as the primary source for your answer. Do not speculate about whether you have access \
to articles — the context IS provided inline when available.
# Your Mission
When a user asks a question, provide a structured response based ONLY on the article \
content provided in the Context section. Do not draw on general knowledge outside those \
sources. Do not provide article titles or article IDs — provide only the concepts the \
articles express.
# Response Format
1. Executive Summary: A 2-3 sentence overview answering the core query.
2. Guidelines & Hints: A markdown list of specific answers/guidelines/hints found in \
the source material."""
# ====================== GENERATION FUNCTION ======================
# ChatML control markers. They are removed from user-controlled text so a chat
# message cannot close its own turn and forge extra system/assistant turns.
_CHATML_MARKERS = ("<|im_start|>", "<|im_end|>")


def _strip_chatml(text: str) -> str:
    """Return *text* with every ChatML control marker removed."""
    for marker in _CHATML_MARKERS:
        text = text.replace(marker, "")
    return text


def _chatml_turn(role: str, content: str) -> str:
    """Format one ChatML turn, terminated by a newline."""
    return f"<|im_start|>{role}\n{content}<|im_end|>\n"


def _content_to_text(content) -> str:
    """Flatten a Gradio chat-message ``content`` value to plain text.

    Handles plain strings, ``{"text": ...}`` content blocks and lists of them.
    Anything else (None, file payloads, components) contributes no text.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        text = content.get("text")
        return text if isinstance(text, str) else ""
    if isinstance(content, list):
        return "\n".join(part for part in map(_content_to_text, content) if part)
    return ""


def _history_turns(history) -> list:
    """Normalise Gradio chat history into a list of ``(role, text)`` tuples.

    Accepts both history shapes Gradio can pass, depending on version and
    ChatInterface ``type``:
      - "messages": dicts with "role" and "content";
      - "tuples": ``[user_text, bot_text]`` pairs.
    """
    turns = []
    for entry in history or ():
        if isinstance(entry, dict):
            turns.append((entry.get("role", "user"), _content_to_text(entry.get("content"))))
        elif isinstance(entry, (list, tuple)) and len(entry) == 2:
            user_text, bot_text = entry
            if user_text is not None:
                turns.append(("user", _content_to_text(user_text)))
            if bot_text is not None:
                turns.append(("assistant", _content_to_text(bot_text)))
    return turns


def _retrieve_context(query, rag_mode, category_filter, article_filter) -> str:
    """Return formatted RAG excerpts for *query*, or "" when none are available.

    Retrieval is best-effort. An unknown backend label, a store that failed to
    load, zero hits, or any retrieval error all yield "" so the question is
    answered without context.
    """
    backend = BACKENDS.get(rag_mode)
    if not backend:
        return ""
    vs = get_vectorstore(backend)
    if not vs:
        return ""
    try:
        docs = _rag_search(
            vs, query, k=5,
            article_filter=article_filter,
            category_filter=category_filter,
        )
        if not docs:
            logger.warning("RAG returned 0 chunks — answering without context")
            return ""
        # Each chunk is tagged with its title and truncated to 700 characters
        # to bound the prompt size.
        context = "\n\n".join(
            f"[Article: {doc.metadata.get('article_title', 'N/A')}] {doc.page_content[:700]}"
            for doc in docs
        )
    except Exception:
        logger.exception("RAG retrieval failed")
        return ""
    logger.info(
        "RAG: %d chunks injected (article=%r, cat=%r)",
        len(docs), article_filter, category_filter,
    )
    return context


def generate_response(
    message, history,
    rag_mode, category_filter, article_filter,
    max_tokens, temperature, top_p, repeat_penalty,
    suppress_thinking,
):
    """Stream the model's answer to *message* for gr.ChatInterface.

    Args:
        message: The user's new chat message.
        history: Previous chat turns as passed by Gradio, in either the
            "messages" or the "tuples" format.
        rag_mode: Retrieval-backend label (a key of BACKENDS); any other value
            disables retrieval.
        category_filter: Category dropdown value used to restrict retrieval.
        article_filter: Article dropdown value used to restrict retrieval.
        max_tokens, temperature, top_p, repeat_penalty: Sampling parameters.
            None falls back to 900 / 0.65 / 0.9 / 1.1.
        suppress_thinking: When truthy, append Qwen3's "/nothink" switch to
            the user turn.

    Yields:
        The accumulated response text after each streamed token.
    """
    # "/nothink" is driven by the checkbox, so drop any copy the user typed.
    # ChatML markers are dropped too, so the message cannot inject turns.
    clean_message = _strip_chatml(message.replace("/nothink", "")).strip()

    # System prompt plus the last 4 history messages, for context-window economy.
    prompt_parts = [_chatml_turn("system", SYSTEM_PROMPT)]
    for role, text in _history_turns(history)[-4:]:
        prompt_parts.append(_chatml_turn(role, _strip_chatml(text)))

    context = _retrieve_context(clean_message, rag_mode, category_filter, article_filter)

    # Qwen3 /nothink MUST appear on its own line at the very end of the user turn.
    # A leading space (e.g. " /nothink") is NOT recognised by the tokeniser.
    nothink_suffix = "\n/nothink" if suppress_thinking else ""
    user_content = f"{clean_message}{nothink_suffix}"
    if context:
        user_content = f"Context:\n{context}\n\nQuestion: {user_content}"
    prompt_parts.append(_chatml_turn("user", user_content))
    prompt_parts.append("<|im_start|>assistant\n")
    full_prompt = "".join(prompt_parts)

    # Sanitise generation params: None falls back to the UI defaults.
    max_tokens_val = int(max_tokens) if max_tokens is not None else 900
    temp_val = float(temperature) if temperature is not None else 0.65
    top_p_val = float(top_p) if top_p is not None else 0.9
    rep_penalty_val = float(repeat_penalty) if repeat_penalty is not None else 1.1

    partial_text = ""
    for chunk in llm(
        full_prompt,
        max_tokens=max_tokens_val,
        temperature=temp_val,
        top_p=top_p_val,
        repeat_penalty=rep_penalty_val,
        stop=["<|im_end|>", "<|im_start|>"],
        stream=True,
    ):
        partial_text += chunk["choices"][0]["text"]
        yield partial_text
# ====================== GRADIO INTERFACE ======================
# Build the UI. Controls rendered before the ChatInterface are handed to it as
# additional_inputs, in the same order as generate_response's parameters after
# (message, history).
with gr.Blocks(title="Article Q&A model") as demo:
    gr.Markdown("# sourcing 350+ articles on change")
    gr.Markdown(
        "Qwen3.5-4B DoRA fine-tuned on 350+ articles on change from robertolofaro.com — "
        "experimental demo on CPU-only, to test embedding methods (takes a few minutes, "
        "you can restrict by category, and then a specific article) — updated as of 2026-05-05"
    )
    # Disclaimer about the scope and nature of the answers.
    gr.Markdown(
        "**NOTAM:** by querying this model you access the articles and metadata "
        "available on robertolofaro.com and GitHub. "
        "Answers reflect the article corpus only — do not treat them as advice, "
        "just expression of a position derived from material contained within the articles. "
        "If you want to read actual positions expressed within articles, you can read the articles "
        "(see the model repository for all links to the available options)."
    )
    gr.Markdown(
        "If, after getting an answer, you want something tailored to your context, "
        "contact a consultant (myself included)."
    )
    # Row 1: retrieval backend (labels are the BACKENDS keys) and the /nothink toggle.
    with gr.Row():
        rag_mode = gr.Radio(
            choices=list(BACKENDS.keys()),
            value="FAISS - RAG (HNSW)",
            label="Retrieval backend",
        )
        suppress_thinking = gr.Checkbox(
            value=True,
            label="Suppress model thinking (/nothink)",
            info="Uncheck to see the model's reasoning chain",
        )
    # Row 2: metadata filters; the article list is refilled from the chosen category.
    with gr.Row():
        category_filter = gr.Dropdown(
            choices=CATEGORY_LIST,
            value="All categories",
            label="Filter by category",
            # CATEGORY_LIST always starts with the "All categories" placeholder.
            info=f"{len(CATEGORY_LIST) - 1} categories available",
        )
        article_filter = gr.Dropdown(
            choices=["All articles in category"],
            value="All articles in category",
            label="Narrow to specific article (optional)",
            info="Select a category first to populate this list",
        )
    # Dynamically populate the article dropdown when category changes
    def update_article_dropdown(category):
        """Return a Dropdown update listing *category*'s articles, selecting the first choice."""
        articles = load_articles_for_category(category)
        return gr.Dropdown(choices=articles, value=articles[0])
    category_filter.change(
        fn=update_article_dropdown,
        inputs=category_filter,
        outputs=article_filter,
    )
    # Sampling controls; their defaults match the fallbacks used for None values.
    with gr.Accordion("Advanced Generation Parameters", open=False):
        max_tokens = gr.Slider(256, 2048, value=900, step=64, label="Max Tokens")
        temperature = gr.Slider(0.0, 1.0, value=0.65, step=0.05, label="Temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p")
        repeat_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repeat Penalty")
    # NOTE(review): no `type=` is passed, so the history format handed to
    # generate_response follows the installed Gradio version's default. Confirm
    # it matches what generate_response reads from `history`.
    gr.ChatInterface(
        fn=generate_response,
        additional_inputs=[
            rag_mode, category_filter, article_filter,
            max_tokens, temperature, top_p, repeat_penalty,
            suppress_thinking,
        ],
        cache_examples=False,
        examples=[
            ["What is the potential for Italy?"],
            ["What is the potential for Turin?"],
        ],
    )
if __name__ == "__main__":
    # Cap every event at one concurrent run; further requests wait in the queue.
    queued_app = demo.queue(default_concurrency_limit=1)
    queued_app.launch()