File size: 17,588 Bytes
344db6e
5f5a49f
762b148
196a72d
 
 
 
 
5f5a49f
3545fe7
 
344db6e
762b148
 
 
 
5f5a49f
7af5e1b
5f5a49f
 
 
196a72d
5f5a49f
 
196a72d
 
 
 
7af5e1b
e26d588
196a72d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
762b148
 
 
5f5a49f
762b148
5f5a49f
762b148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f5a49f
762b148
 
5f5a49f
 
196a72d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7af5e1b
 
762b148
3545fe7
 
5f5a49f
 
7af5e1b
 
3545fe7
5f5a49f
762b148
 
 
 
3545fe7
b9f2f8c
3545fe7
5f5a49f
3545fe7
 
 
 
 
 
 
5f5a49f
3545fe7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f5a49f
3545fe7
998b96e
3545fe7
 
 
 
 
 
7af5e1b
762b148
5f5a49f
3545fe7
762b148
 
3545fe7
 
5f5a49f
 
196a72d
 
 
 
 
 
 
3545fe7
d462d5f
 
 
 
 
 
 
 
 
 
 
 
 
196a72d
3545fe7
 
196a72d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3545fe7
 
196a72d
 
 
 
 
3545fe7
196a72d
 
 
 
3545fe7
 
 
 
 
196a72d
3545fe7
 
 
 
 
 
196a72d
3545fe7
 
 
 
 
 
196a72d
3545fe7
 
 
 
 
 
 
 
 
 
 
 
 
 
196a72d
3545fe7
 
 
 
 
196a72d
7af5e1b
762b148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344db6e
762b148
 
344db6e
5f5a49f
 
762b148
 
 
 
 
 
196a72d
762b148
 
196a72d
7af5e1b
285322b
 
 
762b148
285322b
7af5e1b
 
285322b
 
 
 
196a72d
 
 
 
 
762b148
 
 
 
 
 
196a72d
 
 
 
762b148
196a72d
762b148
 
 
196a72d
 
 
285322b
 
762b148
 
 
 
285322b
196a72d
 
 
285322b
 
 
762b148
196a72d
 
 
 
285322b
7af5e1b
285322b
 
7af5e1b
 
 
 
285322b
 
 
762b148
7af5e1b
 
285322b
762b148
285322b
 
 
fc5ec95
 
9251225
196a72d
fc5ec95
 
 
 
762b148
196a72d
 
 
fc5ec95
 
 
 
 
285322b
 
 
 
 
762b148
 
 
 
 
 
 
 
 
 
 
 
 
 
285322b
 
762b148
 
 
 
285322b
 
762b148
 
 
 
 
 
 
 
 
 
 
285322b
196a72d
 
 
 
285322b
 
 
762b148
 
 
 
 
 
285322b
762b148
 
285322b
 
 
 
7af5e1b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
import os
import pickle
import logging
import platform

import gradio as gr
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
from langchain_huggingface import HuggingFaceEmbeddings
# Qdrant filter models
from qdrant_client.http.models import Filter, FieldCondition, MatchValue

# ====================== LOGGING ======================
# Root logging is configured once at import; every message is rendered as
# "<LEVEL> | <message>".
logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
logger = logging.getLogger(__name__)

# ====================== CONFIG ======================
# Hugging Face Hub repository the GGUF model file is downloaded from.
repo_id = "robertolofaro/articles-model"

# UI label (shown in the "Retrieval backend" radio) -> internal backend key
# understood by get_vectorstore().
BACKENDS = {
    "FAISS - RAG (HNSW)": "FAISS",
    "Qdrant - RAG": "Qdrant",
}

# All data artefacts are resolved relative to this file's directory.
_HERE             = os.path.dirname(os.path.abspath(__file__))
METADATA_PATH     = os.path.join(_HERE, "metadata.pkl")   # pickled metadata object
FAISS_PATH        = os.path.join(_HERE, "faiss_hnsw")     # FAISS index location
QDRANT_PATH       = os.path.join(_HERE, "qdrant_db")      # local Qdrant storage folder
QDRANT_COLLECTION = "articles"                            # Qdrant collection name

# ====================== GPU / HARDWARE DETECTION ======================
# Override everything with N_GPU_LAYERS env var when you need fine control.
# Otherwise: CUDA → all layers on GPU (-1); Apple Silicon → Metal (-1); else CPU (0).
def _detect_gpu_layers() -> int:
    """Decide how many model layers should be offloaded to the GPU.

    Resolution order:
      1. ``N_GPU_LAYERS`` environment variable, when it parses as an integer.
      2. CUDA reported as available by ``torch`` (only if ``torch`` imports).
      3. macOS on arm64 (Apple Silicon / Metal).
      4. CPU only.

    Returns:
        The integer later passed to ``Llama`` as ``n_gpu_layers``: the
        override value, ``-1`` (all layers on GPU) or ``0`` (CPU only).
    """
    override = os.environ.get("N_GPU_LAYERS")
    if override is not None:
        try:
            val = int(override)
        except ValueError:
            # A malformed (or empty) override must not abort the import with
            # a ValueError; ignore it and fall through to auto-detection.
            logger.warning(
                "Ignoring invalid N_GPU_LAYERS=%r (not an integer) — auto-detecting",
                override,
            )
        else:
            logger.info("N_GPU_LAYERS override: %d", val)
            return val

    try:
        import torch
        if torch.cuda.is_available():
            logger.info("CUDA detected — offloading all layers to GPU")
            return -1
    except ImportError:
        # torch is optional; without it CUDA simply cannot be probed here.
        pass

    if platform.system() == "Darwin" and platform.machine() == "arm64":
        logger.info("Apple Silicon / Metal detected — offloading all layers to GPU")
        return -1

    logger.info("No GPU detected — running on CPU only")
    return 0

N_GPU_LAYERS = _detect_gpu_layers()

# ====================== LOAD METADATA ======================
def _load_metadata():
    """Load the pickled metadata object stored at ``METADATA_PATH``.

    Returns:
        The unpickled object (used as a DataFrame: the success log reads
        ``len(df)`` and ``df.columns``), or ``None`` when the file is missing
        or any other error occurs while loading it.
    """
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserialising, so metadata.pkl must only ever be a trusted file.
    try:
        with open(METADATA_PATH, "rb") as f:
            df = pickle.load(f)
        # Logged inside the try: a payload lacking len()/.columns is then
        # reported as a load failure by the generic handler below.
        logger.info("metadata.pkl loaded — %d rows, columns: %s", len(df), df.columns.tolist())
        return df
    except FileNotFoundError:
        logger.error("metadata.pkl not found at %s", METADATA_PATH)
    except Exception as exc:
        logger.error("Failed to load metadata.pkl: %s", exc)
    return None

_METADATA_DF = _load_metadata()


def load_category_list():
    """Build the choices for the category dropdown.

    Returns:
        ``["All categories"]`` followed by the sorted, de-duplicated, non-null
        values of the ``article_category`` column. When the metadata could
        not be loaded or has no such column, only ``["All categories"]``.
    """
    if _METADATA_DF is None or "article_category" not in _METADATA_DF.columns:
        logger.warning("article_category column not found — showing only 'All categories'")
        return ["All categories"]

    cats = sorted(_METADATA_DF["article_category"].dropna().unique().tolist())
    logger.info("Found %d categories", len(cats))
    return ["All categories"] + cats


def load_articles_for_category(category: str):
    """Build the choices for the article dropdown.

    Args:
        category: Selected category. ``"All categories"``, ``None`` and ``""``
            all mean "do not restrict by category".

    Returns:
        ``["All articles in category"]`` followed by the sorted, de-duplicated,
        non-null article titles (restricted to ``category`` when one is
        given). Only the default entry is returned when the metadata is
        unavailable, has no ``article_title`` column, or a category is
        requested but there is no ``article_category`` column to filter on.
    """
    default = ["All articles in category"]
    if _METADATA_DF is None or "article_title" not in _METADATA_DF.columns:
        return default

    titles = _METADATA_DF["article_title"]
    if category not in ("All categories", None, ""):
        # Guard the category column as well: indexing it blindly raised
        # KeyError whenever the metadata lacked "article_category".
        if "article_category" not in _METADATA_DF.columns:
            return default
        titles = titles[_METADATA_DF["article_category"] == category]

    return default + sorted(titles.dropna().unique().tolist())


# Category choices for the UI dropdown, computed once at import time.
CATEGORY_LIST = load_category_list()

# ====================== LOAD LLM ======================
# LOCAL_MODEL_PATH env var lets you point to a local GGUF and skip the HF download.
# N_THREADS env var overrides thread count (default: 4 on CPU, 2 on GPU).
def _load_llm() -> Llama:
    """Create the ``Llama`` instance used for generation.

    Model file: ``LOCAL_MODEL_PATH`` is used when it is set and points to an
    existing file; otherwise ``articles-Q4_K_M.gguf`` is downloaded from the
    ``repo_id`` repository on the Hugging Face Hub (``HF_TOKEN`` is passed as
    the token when set).

    Threads: ``N_THREADS`` overrides the default (2 when layers are offloaded
    to a GPU, 4 on CPU). A value that is not an integer is ignored with a
    warning instead of crashing start-up.

    Returns:
        The initialised ``Llama`` object.
    """
    local_model = os.environ.get("LOCAL_MODEL_PATH")
    if local_model and os.path.isfile(local_model):
        model_path = local_model
        logger.info("Using local model at %s", model_path)
    else:
        if local_model:
            logger.warning("LOCAL_MODEL_PATH set but file not found (%s) — downloading from HF", local_model)
        logger.info("Downloading model from HF hub (%s)…", repo_id)
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename="articles-Q4_K_M.gguf",
            repo_type="model",
            token=os.environ.get("HF_TOKEN"),
        )

    default_threads = 2 if N_GPU_LAYERS != 0 else 4
    n_threads = default_threads
    raw_threads = os.environ.get("N_THREADS")
    if raw_threads is not None:
        try:
            n_threads = int(raw_threads)
        except ValueError:
            # int() on a malformed override used to abort the import.
            logger.warning(
                "Ignoring invalid N_THREADS=%r (not an integer) — using %d",
                raw_threads, default_threads,
            )
    logger.info("Llama init: n_gpu_layers=%d, n_threads=%d", N_GPU_LAYERS, n_threads)

    return Llama(
        model_path=model_path,
        n_ctx=8192,
        n_threads=n_threads,
        n_batch=512,
        n_ubatch=512,
        n_gpu_layers=N_GPU_LAYERS,
        verbose=False,
    )

llm = _load_llm()

# ====================== RAG CACHE ======================

# ====================== VECTOR STORES ======================
# Cache of already-loaded vector stores, keyed by the backend name passed to
# get_vectorstore().
vectorstores: dict = {}


def _load_faiss_store(embeddings):
    """Load the FAISS index stored at ``FAISS_PATH`` using ``embeddings``."""
    # Modern recommended import (still under langchain-community)
    from langchain_community.vectorstores import FAISS

    # NOTE(security): allow_dangerous_deserialization lets load_local unpickle
    # data from disk, so FAISS_PATH must only ever hold a trusted index.
    store = FAISS.load_local(
        FAISS_PATH,
        embeddings,
        allow_dangerous_deserialization=True,
    )
    logger.info("FAISS index loaded from %s", FAISS_PATH)
    return store


def _load_qdrant_store(embeddings):
    """Open the local Qdrant collection ``QDRANT_COLLECTION`` under ``QDRANT_PATH``."""
    from langchain_qdrant import QdrantVectorStore
    from qdrant_client import QdrantClient

    client = QdrantClient(
        path=QDRANT_PATH,      # path to the qdrant_db folder
        timeout=60,
    )
    store = QdrantVectorStore(
        client=client,
        collection_name=QDRANT_COLLECTION,
        embedding=embeddings,
    )
    logger.info("Qdrant collection '%s' loaded from %s",
                QDRANT_COLLECTION, QDRANT_PATH)
    return store


def get_vectorstore(backend_name: str):
    """Return the vector store for ``backend_name``, loading it on first use.

    Args:
        backend_name: ``"FAISS"`` or ``"Qdrant"``. Any other value falls back
            to the FAISS index (now with a warning) and is cached under the
            name that was requested.

    Returns:
        The loaded vector store, or ``None`` when loading fails. Failures are
        logged with their traceback and are not cached, so a later call tries
        again.
    """
    if backend_name in vectorstores:
        return vectorstores[backend_name]

    try:
        embeddings = HuggingFaceEmbeddings(
            model_name="BAAI/bge-small-en-v1.5",
            encode_kwargs={"normalize_embeddings": True},
        )

        if backend_name == "Qdrant":
            vs = _load_qdrant_store(embeddings)
        else:
            if backend_name != "FAISS":
                logger.warning("Unknown backend %r — falling back to FAISS", backend_name)
            vs = _load_faiss_store(embeddings)

        vectorstores[backend_name] = vs
        logger.info("Vector store '%s' loaded successfully", backend_name)
        return vs

    except Exception:
        # logger.exception records the message and the active traceback in
        # one ERROR record, replacing the manual traceback.format_exc() call.
        logger.exception("Failed to load vector store '%s'", backend_name)
        return None

def _faiss_post_filter_search(vs, query: str, k: int, want_title, want_category):
    """Over-fetch from FAISS, then keep the first hits matching the metadata filters."""
    # Over-fetch so the post-filter has candidates to choose from, but never
    # fetch fewer than k: min(k * 10, 80) alone drops below k once k > 80.
    pool_size = max(k, min(k * 10, 80))
    pool = vs.similarity_search(query, k=pool_size)

    filtered = []
    for doc in pool:
        meta = doc.metadata
        if want_title and meta.get("article_title") != want_title:
            continue
        if want_category and meta.get("article_category") != want_category:
            continue
        filtered.append(doc)
        if len(filtered) >= k:
            break

    if not filtered and (want_title or want_category):
        # Nothing in the pool matched: return unfiltered results rather than
        # no context at all.
        logger.warning(
            "FAISS post-filter (title=%r, cat=%r) matched 0 docs — returning unfiltered top-%d",
            want_title, want_category, k
        )
        return pool[:k]

    logger.info(
        "FAISS post-filter (title=%r, cat=%r) → %d/%d docs kept",
        want_title, want_category, len(filtered), len(pool)
    )
    return filtered


def _qdrant_filtered_search(vs, query: str, k: int, want_title, want_category):
    """Run a Qdrant similarity search with the metadata filters applied by the store."""
    from qdrant_client.http.models import Filter, FieldCondition, MatchValue

    # Both filters are applied together when both are set, matching the FAISS
    # path (an if/elif here used to drop the category once a title was set).
    # Payload keys carry the "metadata." prefix.
    wanted = (
        ("title", "metadata.article_title", want_title),
        ("category", "metadata.article_category", want_category),
    )
    active = [(label, key, value) for label, key, value in wanted if value]
    conditions = [
        FieldCondition(key=key, match=MatchValue(value=value))
        for _, key, value in active
    ]
    qdrant_filter = Filter(must=conditions) if conditions else None
    filter_label = "+".join(label for label, _, _ in active) or "none"

    try:
        docs = vs.similarity_search(
            query,
            k=k,
            filter=qdrant_filter
        )
    except Exception as e:
        if qdrant_filter is None:
            # There was no filter to drop, so an "unfiltered" retry would be
            # the very same call; let the caller handle the failure.
            raise
        logger.error("Qdrant search failed with filter: %s", e)
        logger.warning("Falling back to unfiltered Qdrant search")
        return vs.similarity_search(query, k=k)

    logger.info("Qdrant search (filter=%s) → %d docs", filter_label, len(docs))
    return docs


def _rag_search(vs, query: str, k: int, article_filter: str, category_filter: str):
    """Similarity search with optional metadata filtering.

    Args:
        vs: Vector store exposing ``similarity_search``. A store whose class
            name contains ``"FAISS"`` is post-filtered in Python; any other
            store is treated as Qdrant and filtered by the store itself.
        query: Text to search for.
        k: Number of documents wanted.
        article_filter: Exact article title to keep; ``None``, ``""`` or
            ``"All articles in category"`` disables the title filter.
        category_filter: Exact category to keep; ``None``, ``""`` or
            ``"All categories"`` disables the category filter.

    Returns:
        The documents returned by the store. For FAISS, when filters are set
        but match nothing in the candidate pool, the unfiltered top-``k`` is
        returned instead. For Qdrant, a failing filtered search is retried
        without the filter.

    Raises:
        Whatever the underlying ``similarity_search`` raises when there is no
        fallback left to try.
    """
    want_title = None if article_filter in (None, "", "All articles in category") else article_filter
    want_category = None if category_filter in (None, "", "All categories") else category_filter

    if "FAISS" in type(vs).__name__:
        return _faiss_post_filter_search(vs, query, k, want_title, want_category)
    return _qdrant_filtered_search(vs, query, k, want_title, want_category)

# ====================== SYSTEM PROMPT ======================
# System message placed at the start of every prompt. Trailing backslashes
# join the wrapped source lines, so each paragraph reaches the model as a
# single line. The em dashes were previously stored mis-encoded.
SYSTEM_PROMPT = """You are the reference expert for the articles contained in the training \
of this model, all extracted from the website robertolofaro.com, and all focused on change.

IMPORTANT: Relevant article excerpts retrieved via semantic search will be injected \
directly in the user message under the heading "Context:". You MUST use those excerpts \
as the primary source for your answer. Do not speculate about whether you have access \
to articles — the context IS provided inline when available.

# Your Mission
When a user asks a question, provide a structured response based ONLY on the article \
content provided in the Context section. Do not draw on general knowledge outside those \
sources. Do not provide article titles or article IDs — provide only the concepts the \
articles express.

# Response Format
1. Executive Summary: A 2-3 sentence overview answering the core query.
2. Guidelines & Hints: A markdown list of specific answers/guidelines/hints found in \
the source material."""


# ====================== GENERATION FUNCTION ======================
def _message_text(content) -> str:
    """Flatten one chat message's content into plain text."""
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        # NOTE(review): some Gradio versions appear to deliver content as a
        # list of blocks such as {"type": "text", "text": ...}; confirm against
        # the installed release. Only the textual parts are kept.
        parts = []
        for item in content:
            if isinstance(item, str):
                parts.append(item)
            elif isinstance(item, dict) and isinstance(item.get("text"), str):
                parts.append(item["text"])
        return "\n".join(parts)
    return "" if content is None else str(content)


def _history_to_chatml(history, max_messages: int = 4) -> str:
    """Render the most recent history entries as ChatML turns."""
    rendered = []
    for entry in (history or [])[-max_messages:]:
        if isinstance(entry, dict):
            # Messages format: {"role": ..., "content": ...}
            turns = [(entry.get("role"), entry.get("content"))]
        elif isinstance(entry, (list, tuple)) and len(entry) == 2:
            # Legacy pair format: (user_text, assistant_text)
            turns = [("user", entry[0]), ("assistant", entry[1])]
        else:
            continue
        for role, content in turns:
            if role and content is not None:
                rendered.append(
                    f"<|im_start|>{role}\n{_message_text(content)}<|im_end|>\n"
                )
    return "".join(rendered)


def _retrieve_context(rag_mode, query: str, article_filter, category_filter) -> str:
    """Return the RAG context block for ``query``, or ``""`` when none is available."""
    backend = BACKENDS.get(rag_mode)
    if not backend:
        return ""
    vs = get_vectorstore(backend)
    if not vs:
        return ""

    try:
        docs = _rag_search(
            vs, query, k=5,
            article_filter=article_filter,
            category_filter=category_filter,
        )
    except Exception as exc:
        # Retrieval is best-effort: answer without context rather than fail.
        logger.error("RAG retrieval failed: %s", exc)
        return ""

    if not docs:
        logger.warning("RAG returned 0 chunks — answering without context")
        return ""

    logger.info(
        "RAG: %d chunks injected (article=%r, cat=%r)",
        len(docs), article_filter, category_filter,
    )
    # Each chunk is truncated to 700 characters to limit prompt size.
    return "\n\n".join(
        f"[Article: {doc.metadata.get('article_title', 'N/A')}] "
        f"{doc.page_content[:700]}"
        for doc in docs
    )


def generate_response(
    message, history,
    rag_mode, category_filter, article_filter,
    max_tokens, temperature, top_p, repeat_penalty,
    suppress_thinking,
):
    """Stream an answer to ``message`` as a growing string.

    Builds a ChatML prompt (system prompt, the last 4 history messages, then
    the user turn with any retrieved context), runs the model in streaming
    mode and yields the accumulated text after every chunk.

    Args:
        message: The user's question; any ``/nothink`` typed by hand is removed.
        history: Previous conversation. Dict messages with ``role`` and
            ``content`` are rendered as before; pair-style entries and
            non-string content are now tolerated as well.
        rag_mode: Key of ``BACKENDS`` selecting the retrieval backend; an
            unknown value disables retrieval.
        category_filter: Category restriction for retrieval.
        article_filter: Article-title restriction for retrieval.
        max_tokens: Generation limit (900 when ``None``).
        temperature: Sampling temperature (0.65 when ``None``).
        top_p: Nucleus sampling value (0.9 when ``None``).
        repeat_penalty: Repetition penalty (1.1 when ``None``).
        suppress_thinking: When true, ``/nothink`` is appended to the user turn.

    Yields:
        The text generated so far, extended by one streamed chunk each time.
    """
    # Strip any /nothink the user may have typed manually
    clean_message = message.replace("/nothink", "").strip()

    context = _retrieve_context(rag_mode, clean_message, article_filter, category_filter)

    # Qwen3 /nothink MUST appear on its own line at the very end of the user turn.
    # A leading space (e.g. " /nothink") is NOT recognised by the tokeniser.
    nothink_suffix = "\n/nothink" if suppress_thinking else ""

    if context:
        user_body = f"Context:\n{context}\n\nQuestion: {clean_message}{nothink_suffix}"
    else:
        user_body = f"{clean_message}{nothink_suffix}"

    # Only the last 4 history messages are replayed, for context window economy.
    full_prompt = "".join((
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n",
        _history_to_chatml(history, 4),
        f"<|im_start|>user\n{user_body}<|im_end|>\n",
        "<|im_start|>assistant\n",
    ))

    # Sanitise generation params
    max_tokens_val  = int(max_tokens)        if max_tokens      is not None else 900
    temp_val        = float(temperature)     if temperature     is not None else 0.65
    top_p_val       = float(top_p)           if top_p           is not None else 0.9
    rep_penalty_val = float(repeat_penalty)  if repeat_penalty  is not None else 1.1

    partial_text = ""
    for chunk in llm(
        full_prompt,
        max_tokens=max_tokens_val,
        temperature=temp_val,
        top_p=top_p_val,
        repeat_penalty=rep_penalty_val,
        stop=["<|im_end|>", "<|im_start|>"],
        stream=True,
    ):
        partial_text += chunk["choices"][0]["text"]
        yield partial_text


# ====================== GRADIO INTERFACE ======================
# Page layout: intro text, retrieval controls, category/article filters,
# advanced generation sliders, then the chat itself. The em dashes in the
# user-visible Markdown were previously stored mis-encoded.
with gr.Blocks(title="Article Q&A model") as demo:
    gr.Markdown("# sourcing 350+ articles on change")
    gr.Markdown(
        "Qwen3.5-4B DoRA fine-tuned on 350+ articles on change from robertolofaro.com — "
        "experimental demo on CPU-only, to test embedding methods (takes a few minutes, "
        "you can restrict by category, and then a specific article) — updated as of 2026-05-05"
    )
    gr.Markdown(
        "**NOTAM:** by querying this model you access the articles and metadata "
        "available on robertolofaro.com and GitHub.  "
        "Answers reflect the article corpus only — do not treat them as advice, "
        "just expression of a position derived from material contained within the articles. "
        "If you want to read actual positions expressed within articles, you can read the articles "
        "(see the model repository for all links to the available options)."
    )
    gr.Markdown(
        "If, after getting an answer, you want something tailored to your context, "
        "contact a consultant (myself included)."
    )

    with gr.Row():
        # Choices are the BACKENDS labels; generate_response maps the label
        # back to the internal backend key.
        rag_mode = gr.Radio(
            choices=list(BACKENDS.keys()),
            value="FAISS - RAG (HNSW)",
            label="Retrieval backend",
        )
        suppress_thinking = gr.Checkbox(
            value=True,
            label="Suppress model thinking (/nothink)",
            info="Uncheck to see the model's reasoning chain",
        )

    with gr.Row():
        category_filter = gr.Dropdown(
            choices=CATEGORY_LIST,
            value="All categories",
            label="Filter by category",
            # CATEGORY_LIST always starts with "All categories", hence the -1.
            info=f"{len(CATEGORY_LIST) - 1} categories available",
        )
        article_filter = gr.Dropdown(
            choices=["All articles in category"],
            value="All articles in category",
            label="Narrow to specific article (optional)",
            info="Select a category first to populate this list",
        )

    # Dynamically populate the article dropdown when category changes
    def update_article_dropdown(category):
        """Return a Dropdown update listing the articles of ``category``."""
        articles = load_articles_for_category(category)
        # The first entry is always "All articles in category", so changing
        # the category resets the article selection to it.
        return gr.Dropdown(choices=articles, value=articles[0])

    category_filter.change(
        fn=update_article_dropdown,
        inputs=category_filter,
        outputs=article_filter,
    )

    with gr.Accordion("Advanced Generation Parameters", open=False):
        max_tokens     = gr.Slider(256, 2048, value=900,  step=64,   label="Max Tokens")
        temperature    = gr.Slider(0.0,  1.0, value=0.65, step=0.05, label="Temperature")
        top_p          = gr.Slider(0.0,  1.0, value=0.9,  step=0.05, label="Top-p")
        repeat_penalty = gr.Slider(1.0,  2.0, value=1.1,  step=0.05, label="Repeat Penalty")

    # The order of additional_inputs must match the parameters of
    # generate_response that follow (message, history).
    gr.ChatInterface(
        fn=generate_response,
        additional_inputs=[
            rag_mode, category_filter, article_filter,
            max_tokens, temperature, top_p, repeat_penalty,
            suppress_thinking,
        ],
        cache_examples=False,
        examples=[
            ["What is the potential for Italy?"],
            ["What is the potential for Turin?"],
        ],
    )

if __name__ == "__main__":
    # One queued request at a time by default.
    demo.queue(default_concurrency_limit=1).launch()