| # ============================================================ | |
| # IMPORTS | |
| # ============================================================ | |
| import re | |
| import os | |
| import math | |
| import pickle | |
| import requests | |
| from collections import Counter | |
| import numpy as np | |
| import pandas as pd | |
| import faiss | |
| import PyPDF2 | |
| import torch | |
| import gradio as gr | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from matplotlib.patches import Patch | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| from langdetect import detect, DetectorFactory | |
| from gtts import gTTS | |
from transformers import pipeline
hf_pipeline = pipeline  # same HF pipeline factory; alias kept for the ASR loader below
| from datetime import datetime | |
| from groq import Groq | |
| from sklearn.preprocessing import MinMaxScaler | |
| from scipy import stats | |
| DetectorFactory.seed = 0 | |
| # ============================================================ | |
| # GROQ SETUP | |
| # ============================================================ | |
| GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "") | |
| groq_client = Groq(api_key=GROQ_API_KEY) | |
| print(f"DEBUG — Groq Key loaded: {bool(GROQ_API_KEY)}") | |
| # ============================================================ | |
| # GLOBAL STATE | |
| # ============================================================ | |
| KB_TEXTS = [] | |
| KB_META = [] | |
| FAISS_INDEX = None | |
| KB_EMB = None | |
| DOC_TYPE_INFO = {"type": "📄 General", "is_economic": False, "score": 0} | |
| PER_FILE_INFO = {} | |
| CHAT_STATS = {"questions": 0, "found": 0, "not_found": 0} | |
| MIN_SIMILARITY = 0.10 | |
| PERSIST_DIR = "/tmp" | |
| KB_TEXTS_PATH = f"{PERSIST_DIR}/kb_texts.pkl" | |
| KB_META_PATH = f"{PERSIST_DIR}/kb_meta.pkl" | |
| FAISS_PATH = f"{PERSIST_DIR}/faiss.index" | |
| os.makedirs(PERSIST_DIR, exist_ok=True) | |
| # ============================================================ | |
| # PERSIST | |
| # ============================================================ | |
| def save_index(): | |
| if FAISS_INDEX is None or not KB_TEXTS: | |
| return "⚠️ No index to save." | |
| try: | |
| with open(KB_TEXTS_PATH, "wb") as f: pickle.dump(KB_TEXTS, f) | |
| with open(KB_META_PATH, "wb") as f: pickle.dump(KB_META, f) | |
| faiss.write_index(FAISS_INDEX, FAISS_PATH) | |
| return f"💾 Saved! {len(KB_TEXTS):,} chunks" | |
| except Exception as e: | |
| return f"❌ Save error: {e}" | |
| def load_saved_index(): | |
| global KB_TEXTS, KB_META, FAISS_INDEX, DOC_TYPE_INFO | |
| try: | |
| if not os.path.exists(FAISS_PATH): | |
| return "_No saved index found._" | |
| with open(KB_TEXTS_PATH, "rb") as f: KB_TEXTS = pickle.load(f) | |
| with open(KB_META_PATH, "rb") as f: KB_META = pickle.load(f) | |
| FAISS_INDEX = faiss.read_index(FAISS_PATH) | |
| DOC_TYPE_INFO = detect_document_type(KB_TEXTS) | |
| return f"✅ **Index loaded!** `{len(KB_TEXTS):,}` chunks\n🏷️ Type: **{DOC_TYPE_INFO['type']}**" | |
| except Exception as e: | |
| return f"❌ Load error: {e}" | |
| # ============================================================ | |
| # KEYWORDS & LEXICONS | |
| # ============================================================ | |
| ECONOMIC_KEYWORDS = [ | |
| "gdp","inflation","monetary","fiscal","forecast","exchange rate", | |
| "interest rate","unemployment","recession","growth rate","trade balance", | |
| "budget deficit","central bank","economic outlook","imf","world bank", | |
| "cpi","macro","revenue","expenditure","deficit","surplus","debt", | |
| "croissance","taux","banque centrale","prévision","économique","pib", | |
| "التضخم","الناتج المحلي","النمو الاقتصادي","البنك المركزي","سعر الصرف", | |
| ] | |
| MEDICAL_KEYWORDS = ["patient","diagnosis","treatment","clinical","hospital","symptom","disease"] | |
| LEGAL_KEYWORDS = ["article","law","contract","clause","jurisdiction","court","legal"] | |
| ACADEMIC_KEYWORDS = ["abstract","methodology","hypothesis","conclusion","references","doi","journal"] | |
| ECON_POSITIVE = [ | |
| "growth","recovery","surplus","improvement","stability","increase", | |
| "expansion","acceleration","resilience","upturn","robust","favorable", | |
| "strengthened","progress","rebound","optimistic","confidence","boom", | |
| "prosper","thrive","advance","gain","rise","positive","upward", | |
| "exceed","outperform","strong","healthy","dynamic","sustainable", | |
| "croissance","reprise","amélioration","stabilité","excédent","hausse", | |
| "expansion","dynamique","favorable","progrès","rebond","solide", | |
| "تعافي","نمو","استقرار","فائض","تحسّن","ارتفاع","توسع","إيجابي", | |
| "تقدم","قوي","ازدهار","انتعاش","تحسين","قوة", | |
| ] | |
| ECON_NEGATIVE = [ | |
| "deficit","recession","inflation","decline","contraction","debt", | |
| "crisis","deterioration","slowdown","downturn","unemployment","pressure", | |
| "risk","vulnerability","shock","uncertainty","war","sanctions", | |
| "drought","collapse","default","volatile","instability","weak", | |
| "fragile","pessimistic","loss","shrink","fall","negative","downward", | |
| "slump","stagnation","turbulence","disruption","imbalance","burden", | |
| "déficit","récession","crise","ralentissement","chômage","incertitude", | |
| "guerre","effondrement","instabilité","baisse","fragilité","pression", | |
| "عجز","تضخم","ركود","انكماش","أزمة","تدهور","بطالة","انخفاض", | |
| "ضغط","مخاطر","صدمة","عدم استقرار","هشاشة","ديون","عقوبات", | |
| ] | |
| ECON_TRIGGER = [ | |
| "deficit","risk","crisis","recession","shock","uncertainty", | |
| "slowdown","pressure","vulnerable","weak","deteriorat","downturn", | |
| "contraction","debt","unemployment","inflation","collapse","volatile", | |
| "instability","fragile","stagnation","disruption","sanctions","drought", | |
| "growth","recovery","improvement","surplus","stable","expansion", | |
| "resilience","rebound","strengthened","acceleration","robust", | |
| "favorable","progress","increase","upturn","confidence","boom", | |
| "gdp","forecast","outlook","trade","fiscal","monetary","exchange", | |
| "interest","budget","revenue","expenditure","policy","reform", | |
| "التضخم","الناتج","النمو","العجز","المخاطر","التوقعات","الميزانية", | |
| "croissance","déficit","récession","prévision","taux","politique", | |
| ] | |
| def economic_lexicon_score(text: str) -> float: | |
| text_lower = text.lower() | |
| pos = sum(1 for w in ECON_POSITIVE if w in text_lower) | |
| neg = sum(1 for w in ECON_NEGATIVE if w in text_lower) | |
| total = max(pos + neg, 1) | |
| return round((pos - neg) / total, 4) | |
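# Illustrative example (hypothetical sentence, not from the corpus): for
# "Robust growth and a budget surplus supported the recovery.", the substring
# match above finds 4 positive entries (growth, surplus, recovery, robust) and
# 0 negative entries, so the lexicon score is (4 - 0) / 4 = +1.0.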
| def detect_document_type(texts: list) -> dict: | |
| if not texts: | |
| return {"type":"📄 General","is_economic":False,"score":0,"confidence":0.0} | |
| full_text = " ".join(texts[:30]).lower() | |
| scores = { | |
| "economic": sum(1 for kw in ECONOMIC_KEYWORDS if kw in full_text), | |
| "medical" : sum(1 for kw in MEDICAL_KEYWORDS if kw in full_text), | |
| "legal" : sum(1 for kw in LEGAL_KEYWORDS if kw in full_text), | |
| "academic": sum(1 for kw in ACADEMIC_KEYWORDS if kw in full_text), | |
| "general" : 1, | |
| } | |
| doc_type = max(scores, key=scores.get) | |
| confidence = round(scores[doc_type] / max(sum(scores.values()), 1), 2) | |
| icons = { | |
| "economic":"📊 Economic","medical":"🏥 Medical", | |
| "legal":"⚖️ Legal","academic":"🎓 Academic","general":"📄 General", | |
| } | |
| return { | |
| "type" : icons.get(doc_type, "📄 General"), | |
| "raw_type" : doc_type, | |
| "is_economic": doc_type == "economic" and scores["economic"] >= 3, | |
| "score" : scores[doc_type], | |
| "confidence" : confidence, | |
| } | |
| # ============================================================ | |
| # AI MODELS — Ensemble: FinBERT 40% + XLM 30% + Lexicon 30% | |
| # ============================================================ | |
| WEIGHTS = {"finbert": 0.40, "xlm": 0.30, "lexicon": 0.30} | |
| print("⏳ Loading FinBERT...") | |
| try: | |
| finbert_pipe = pipeline( | |
| "text-classification", | |
| model="ProsusAI/finbert", | |
| tokenizer="ProsusAI/finbert", | |
| return_all_scores=True, | |
| device=0 if torch.cuda.is_available() else -1, | |
| ) | |
| FINBERT_OK = True | |
| print("✅ FinBERT loaded!") | |
| except Exception as e: | |
| print(f"⚠️ FinBERT failed: {e}") | |
| finbert_pipe = None | |
| FINBERT_OK = False | |
| print("⏳ Loading XLM-RoBERTa...") | |
| try: | |
| xlm_pipe = pipeline( | |
| "text-classification", | |
| model="cardiffnlp/twitter-xlm-roberta-base-sentiment", | |
| tokenizer="cardiffnlp/twitter-xlm-roberta-base-sentiment", | |
| return_all_scores=True, | |
| device=0 if torch.cuda.is_available() else -1, | |
| ) | |
| XLM_OK = True | |
| print("✅ XLM-RoBERTa loaded!") | |
| except Exception as e: | |
| print(f"⚠️ XLM-RoBERTa failed: {e}") | |
| xlm_pipe = None | |
| XLM_OK = False | |
| def normalize_clf(raw): | |
| if isinstance(raw, list) and raw and isinstance(raw[0], list): | |
| raw = raw[0] | |
| return raw if isinstance(raw, list) else [raw] | |
| def clf_finbert(text: str) -> float: | |
| if not FINBERT_OK or finbert_pipe is None: return 0.0 | |
| try: | |
| items = normalize_clf(finbert_pipe(text[:512])) | |
| d = {r["label"].lower(): float(r["score"]) for r in items} | |
| return round(d.get("positive", 0.0) - d.get("negative", 0.0), 4) | |
| except: return 0.0 | |
| def clf_xlm(text: str) -> float: | |
| if not XLM_OK or xlm_pipe is None: return 0.0 | |
| try: | |
| items = normalize_clf(xlm_pipe(text[:512])) | |
| d = {r["label"]: float(r["score"]) for r in items} | |
| pos = d.get("LABEL_2", d.get("positive", d.get("Positive", 0.0))) | |
| neg = d.get("LABEL_0", d.get("negative", d.get("Negative", 0.0))) | |
| return round(pos - neg, 4) | |
| except: return 0.0 | |
| def sentiment_score_numeric(text: str) -> float: | |
| fb = clf_finbert(text) | |
| xlm = clf_xlm(text) | |
| lex = economic_lexicon_score(text) | |
| return round(WEIGHTS["finbert"]*fb + WEIGHTS["xlm"]*xlm + WEIGHTS["lexicon"]*lex, 4) | |
| def run_sentiment(text: str): | |
| score = sentiment_score_numeric(text) | |
| if score > 0.05: sent = "Positive 😊" | |
| elif score < -0.05: sent = "Negative 😞" | |
| else: sent = "Neutral 😐" | |
| return sent, round(min(abs(score), 1.0), 4) | |
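# Worked example with hypothetical component scores: if FinBERT returns +0.62,
# XLM-RoBERTa +0.10 and the lexicon +0.50, the ensemble score is
# 0.40*0.62 + 0.30*0.10 + 0.30*0.50 = 0.428 > 0.05, so run_sentiment()
# labels the text "Positive 😊" with confidence min(|0.428|, 1.0) = 0.428.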
| def run_sentiment_detailed(text: str) -> str: | |
| fb = clf_finbert(text) | |
| xlm = clf_xlm(text) | |
| lex = economic_lexicon_score(text) | |
| final = sentiment_score_numeric(text) | |
| def bar(s): | |
| filled = max(0, min(10, round((s + 1) / 2 * 10))) | |
| icon = "🟩" if s > 0.05 else "🟥" if s < -0.05 else "🟨" | |
| return icon * filled + "⬜" * (10 - filled) | |
| label = "🟢 **Positive**" if final > 0.05 else "🔴 **Negative**" if final < -0.05 else "🟡 **Neutral**" | |
| return ( | |
| f"### 🏆 Ensemble Sentiment Breakdown\n\n" | |
| f"| Model | Score | Bar | Weight |\n|---|---|---|---|\n" | |
| f"| 🏦 FinBERT | `{fb:+.4f}` | {bar(fb)} | **40%** |\n" | |
| f"| 🌍 XLM-RoBERTa | `{xlm:+.4f}` | {bar(xlm)} | **30%** |\n" | |
| f"| 📖 Lexicon | `{lex:+.4f}` | {bar(lex)} | **30%** |\n" | |
| f"| ⚡ **Final** | **`{final:+.4f}`** | {bar(final)} | **100%** |\n\n" | |
| f"{label}" | |
| ) | |
| # ============================================================ | |
| # EMBEDDING + RERANKER + ASR | |
| # ============================================================ | |
| print("⏳ Loading Embedder, Reranker, ASR...") | |
| embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") | |
| reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=512) | |
| asr = hf_pipeline( | |
| "automatic-speech-recognition", | |
| model="openai/whisper-small", | |
| device=0 if torch.cuda.is_available() else -1, | |
| ) | |
| _ = embedder.encode(["warmup"], convert_to_numpy=True) | |
| print("✅ All models loaded!") | |
| _startup = load_saved_index() | |
| print(f"🔄 Startup load: {_startup}") | |
| # ============================================================ | |
| # RAG CORE | |
| # ============================================================ | |
| def clean_filename(path: str) -> str: | |
| return os.path.basename(str(path)) | |
| def detect_lang(text: str) -> str: | |
| try: | |
| return "ar" if str(detect(str(text)[:300])).startswith("ar") else "en" | |
| except: | |
| return "en" | |
| def extract_year_from_filename(filename: str): | |
| full_path = str(filename).replace("\\", "/") | |
| for part in reversed(full_path.split("/")): | |
| m = re.findall(r"\b(20\d{2}|19\d{2})\b", part) | |
| if m: return int(m[0]) | |
| for pat in [r'WEO[_\-\s]?(\d{4})', r'BOA[_\-\s]?(\d{4})', | |
| r'IMF[_\-\s]?(\d{4})', r'rapport[_\-\s]?(\d{4})', | |
| r'report[_\-\s]?(\d{4})']: | |
| m = re.search(pat, full_path, re.IGNORECASE) | |
| if m: return int(m.group(1)) | |
| all_y = re.findall(r'\b(19\d{2}|20\d{2})\b', full_path) | |
| return int(all_y[0]) if all_y else None | |
| def chunk_text(text, chunk_size=300, overlap=80): | |
| text = re.sub(r"\s+", " ", str(text)).strip() | |
| sentences = re.split(r"(?<=[.!?؟\n])\s+", text) | |
| chunks, current = [], "" | |
| for sent in sentences: | |
| if len(current) + len(sent) <= chunk_size: | |
| current += " " + sent | |
| else: | |
| if current.strip(): chunks.append(current.strip()) | |
| words = current.split() | |
            current = (" ".join(words[-overlap // 5:]) + " " + sent) if words else sent
| if current.strip(): chunks.append(current.strip()) | |
| return [c for c in chunks if len(c) > 30] | |
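# chunk_text() splits on sentence boundaries, packs sentences into ~300-character
# chunks, carries the last overlap//5 = 16 words of a completed chunk into the
# next one, and drops any fragment of 30 characters or fewer.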
| def load_file(path): | |
| path = str(path) | |
| if path.endswith(".pdf"): | |
| pages = [] | |
| try: | |
| import pypdf | |
| with open(path, "rb") as f: | |
| reader = pypdf.PdfReader(f) | |
| for i, pg in enumerate(reader.pages[:50]): | |
| t = pg.extract_text() | |
| if t and t.strip(): pages.append({"text": t, "page": i+1}) | |
| except: pass | |
| if not pages: | |
| try: | |
| with open(path, "rb") as f: | |
| reader = PyPDF2.PdfReader(f) | |
| for i, pg in enumerate(reader.pages[:50]): | |
| t = pg.extract_text() | |
| if t and t.strip(): pages.append({"text": t, "page": i+1}) | |
| except: pass | |
| return pages or [{"text": "Could not extract text.", "page": 1}] | |
| if path.endswith(".docx"): | |
| try: | |
| from docx import Document | |
| doc = Document(path) | |
| pars = [p.text for p in doc.paragraphs if p.text.strip()] | |
| return [{"text": "\n".join(pars[i:i+50]), "page": i//50+1} | |
| for i in range(0, len(pars), 50)] or [{"text":"Empty DOCX.","page":1}] | |
| except Exception as e: | |
| return [{"text": f"DOCX error: {e}", "page": 1}] | |
| if path.endswith(".csv"): | |
| df = pd.read_csv(path) | |
| col = "text" if "text" in df.columns else df.columns[0] | |
| return [{"text": t, "page": i+1} | |
| for i, t in enumerate(df[col].dropna().astype(str))] | |
| with open(path, "r", encoding="utf-8", errors="ignore") as f: | |
| return [{"text": f.read(), "page": 1}] | |
| def build_index(files): | |
| global KB_TEXTS, KB_META, FAISS_INDEX, KB_EMB, DOC_TYPE_INFO, PER_FILE_INFO | |
| KB_TEXTS, KB_META, PER_FILE_INFO = [], [], {} | |
| if not files: raise gr.Error("⚠️ Upload at least one file.") | |
| file_paths = [] | |
| if not isinstance(files, list): files = [files] | |
| for f in files: | |
| if isinstance(f, str): file_paths.append(f) | |
| elif isinstance(f, dict): file_paths.append(f.get("path") or f.get("name") or str(f)) | |
| elif hasattr(f, "name"): file_paths.append(f.name) | |
| else: file_paths.append(str(f)) | |
| for p in file_paths: | |
| full_path = str(p) | |
| fname = clean_filename(full_path) | |
| year = extract_year_from_filename(fname) or extract_year_from_filename(full_path) | |
| pages = load_file(full_path) | |
| file_texts = [] | |
| for pg in pages: | |
| for ch in chunk_text(pg["text"]): | |
| KB_TEXTS.append(ch) | |
| KB_META.append({"name": fname, "lang": detect_lang(ch), | |
| "page": pg["page"], "year": year}) | |
| file_texts.append(ch) | |
| ti = detect_document_type(file_texts) | |
| ti["year"] = year | |
| PER_FILE_INFO[fname] = ti | |
| if not KB_TEXTS: raise gr.Error("⚠️ No text extracted.") | |
| KB_EMB = embedder.encode( | |
| KB_TEXTS, convert_to_numpy=True, | |
| normalize_embeddings=True, show_progress_bar=False | |
| ).astype("float32") | |
| FAISS_INDEX = faiss.IndexFlatIP(KB_EMB.shape[1]) | |
| FAISS_INDEX.add(KB_EMB) | |
| DOC_TYPE_INFO = detect_document_type(KB_TEXTS) | |
| lang_count = Counter(m["lang"] for m in KB_META) | |
| tbl = "| 📄 File | 📅 Year | 🏷️ Type | 🎯 Conf | 📦 Chunks |\n|---|---|---|---|---|\n" | |
| for fname, info in PER_FILE_INFO.items(): | |
| n = sum(1 for m in KB_META if m["name"] == fname) | |
| yr = str(info.get("year","N/A")) | |
| yrb = f"{yr} ✅" if yr not in ["None","N/A"] else "N/A ⚠️" | |
| badge = " 🟢" if info["is_economic"] else "" | |
| tbl += f"| `{fname}` | {yrb} | {info['type']}{badge} | {info['confidence']:.0%} | {n} |\n" | |
| ef = [f for f,i in PER_FILE_INFO.items() if i["is_economic"]] | |
| fmsg = ( | |
| f"\n\n🟢 **Economic files detected:** " + | |
| ", ".join(f"`{f}`" for f in ef) + | |
| "\n➡️ Go to **📈 7 · Forecast** tab to run predictions." | |
| ) if ef else "" | |
| save_index() | |
| return ( | |
| f"✅ **Index built!**\n\n" | |
| f"| | |\n|---|---|\n" | |
| f"| 📦 Total chunks | **{len(KB_TEXTS):,}** |\n" | |
| f"| 📄 Files | **{len(file_paths)}** |\n" | |
| f"| 🇸🇦 Arabic | **{lang_count.get('ar',0):,}** |\n" | |
| f"| 🇺🇸 English | **{lang_count.get('en',0):,}** |\n\n" | |
| f"---\n### 📋 Per-File Analysis\n\n{tbl}{fmsg}" | |
| ) | |
| def bm25_score(query_terms, doc, k1=1.5, b=0.75, avg_dl=200): | |
| try: | |
| if not KB_TEXTS or not isinstance(doc, str): return 0.0 | |
| dl, score = len(doc.split()), 0.0 | |
| df = Counter(doc.lower().split()) | |
| for term in query_terms: | |
| if not isinstance(term, str) or not term: continue | |
| tl = term.lower() | |
| n_doc = sum(1 for t in KB_TEXTS if isinstance(t,str) and tl in t.lower()) | |
| tf = df.get(tl, 0) | |
| idf = math.log((len(KB_TEXTS)+1)/(1+n_doc)) | |
| score += idf*(tf*(k1+1))/(tf+k1*(1-b+b*dl/max(avg_dl,1))) | |
| return score | |
| except: return 0.0 | |
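# Per-term contribution used above (a smoothed BM25 variant, not the textbook
# Robertson/Sparck-Jones form):
#   idf(t)   = ln((N + 1) / (1 + n_t)), with N = len(KB_TEXTS) and n_t = chunks containing t
#   score(t) = idf(t) * tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avg_dl))
# where tf is the count of t in the chunk and dl is the chunk length in words.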
| def rag_retrieve(query, k=5, top_n=3): | |
| if FAISS_INDEX is None or not KB_TEXTS: return [] | |
| try: | |
| q_emb = embedder.encode( | |
| [query], convert_to_numpy=True, normalize_embeddings=True | |
| ).astype("float32") | |
| scores, idx = FAISS_INDEX.search(q_emb, min(k*3, len(KB_TEXTS))) | |
| candidates, qterms = [], [t for t in re.findall(r"\w+", str(query).lower()) if t] | |
| for rank, i in enumerate(idx[0]): | |
| if i == -1: continue | |
| sem = float(scores[0][rank]) | |
| if sem < MIN_SIMILARITY: continue | |
| text = KB_TEXTS[i] | |
| if not isinstance(text, str): continue | |
| kw = bm25_score(qterms, text) | |
| lterms = [t for t in qterms if len(t) > 2] | |
| try: | |
| exact = all(re.search(rf"\b{re.escape(t)}\b", text.lower()) for t in lterms) if lterms else False | |
| except: exact = False | |
| hybrid = sem*0.6 + min(kw/10, 0.4) + (0.15 if exact else 0.0) | |
| candidates.append({ | |
| "idx": i, "sem": sem, "kw": kw, "exact": exact, "hybrid": hybrid, | |
| "lang": KB_META[i]["lang"], "file": KB_META[i]["name"], | |
| "page": KB_META[i]["page"], "year": KB_META[i].get("year"), | |
| "text": text, | |
| }) | |
| if not candidates: return [] | |
| ce_scores = reranker.predict([[query, c["text"]] for c in candidates]) | |
| for c, ce in zip(candidates, ce_scores): | |
| c["ce_score"] = float(ce) | |
| c["final"] = c["hybrid"]*0.4 + (float(ce)+10)/20*0.6 | |
| candidates.sort(key=lambda x: x["final"], reverse=True) | |
| for i, c in enumerate(candidates[:top_n]): c["rank"] = i+1 | |
| return candidates[:top_n] | |
| except Exception as e: | |
| print(f"rag_retrieve error: {e}") | |
| return [] | |
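# Ranking recipe implemented above:
#   hybrid = 0.6 * cosine_similarity + min(bm25 / 10, 0.4) + 0.15 * exact_match_bonus
#   final  = 0.4 * hybrid + 0.6 * (cross_encoder_logit + 10) / 20
# Candidates below MIN_SIMILARITY are discarded before reranking; the top_n
# survivors are returned with 1-based ranks.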
| def get_economic_chunks(texts: list, max_chunks: int = 40) -> list: | |
| n = len(texts) | |
| econ = [t for t in texts if any(kw in t.lower() for kw in ECON_TRIGGER)] | |
| if len(econ) < 10: | |
| start = texts[:min(10, n)] | |
| mid = texts[n//2-5 : n//2+5] if n > 20 else [] | |
| end = texts[-min(10, n):] | |
| econ = list(dict.fromkeys(start + mid + end)) | |
| if len(econ) > max_chunks: | |
| step = max(1, len(econ) // max_chunks) | |
| sample = econ[::step][:max_chunks] | |
| else: | |
| sample = econ | |
| return sample | |
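# Sampling strategy: keep chunks containing an ECON_TRIGGER term; if fewer than
# 10 match, fall back to the first 10 + middle 10 + last 10 chunks (deduplicated),
# then stride-sample down to max_chunks so sentiment scoring stays bounded.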
| def llm_groq(question, rag_context, history, lang): | |
| system_prompt = ( | |
| "You are a smart multilingual AI assistant.\n" | |
| "- Always reply in the SAME language as the user question.\n" | |
| "- If Arabic reply fully in Arabic. If English reply fully in English.\n" | |
| "- Use document context precisely and cite page numbers.\n" | |
| "- If answer not in docs, use general knowledge and say so.\n" | |
| "- Be concise, helpful, accurate." | |
| ) | |
| messages = [{"role": "system", "content": system_prompt}] | |
| for turn in history[-4:]: | |
| messages.append({"role": turn["role"], "content": turn["content"]}) | |
| user_content = f"📄 Context:\n{rag_context}\n\nQuestion: {question}" if rag_context else question | |
| messages.append({"role": "user", "content": user_content}) | |
| try: | |
| r = groq_client.chat.completions.create( | |
| model="llama-3.3-70b-versatile", | |
| messages=messages, | |
| temperature=0.3, | |
| max_tokens=512, | |
| ) | |
| return r.choices[0].message.content.strip() | |
| except Exception as e: | |
| return f"⚠️ Groq error: {e}" | |
| def smart_answer(question, history): | |
| lang = detect_lang(question) | |
| results = rag_retrieve(question, k=5, top_n=3) | |
| rag_context = "" | |
| if results: | |
| for r in results: | |
| rag_context += f"[Source: {r['file']} - Page {r['page']}]\n{r['text']}\n\n" | |
| has_good_rag = bool(results) and results[0]["sem"] >= 0.25 | |
| answer_text = llm_groq(question, rag_context[:2000], history, lang) | |
| if has_good_rag: | |
| src = ", ".join(f"`{r['file']}` p.{r['page']}" for r in results) | |
| badge = f"\n\n📄 **{'المصدر' if lang=='ar' else 'Source'}:** {src}" | |
| CHAT_STATS["found"] += 1 | |
| else: | |
| badge = f"\n\n_🤖 {'إجابة عامة.' if lang=='ar' else 'General knowledge answer.'}_" | |
| CHAT_STATS["not_found"] += 1 | |
| CHAT_STATS["questions"] += 1 | |
| return answer_text + badge, "rag" if has_good_rag else "llm" | |
| def predict_with_rag(text): | |
| text = "" if text is None else str(text).strip() | |
| if not text: raise gr.Error("⚠️ Enter text first.") | |
| lang = detect_lang(text) | |
| qterms = [t for t in re.findall(r"\w+", text.lower()) if len(t) > 2] | |
| exact_hits = [] | |
| for i, chunk in enumerate(KB_TEXTS): | |
| if not isinstance(chunk, str): continue | |
| cl = chunk.lower() | |
| for term in qterms: | |
| try: | |
| if re.search(rf"\b{re.escape(term)}\b", cl): | |
| for s in re.split(r"(?<=[.!?؟\n])\s+", chunk): | |
| if re.search(rf"\b{re.escape(term)}\b", s.lower()): | |
| exact_hits.append({ | |
| "word": term, "file": KB_META[i]["name"], | |
| "sentence": s.strip(), "lang": KB_META[i]["lang"], | |
| "chunk_id": i, "page": KB_META[i]["page"], | |
| }) | |
| except: continue | |
| sem_results, md = rag_retrieve(text, k=5, top_n=3), "" | |
| if exact_hits: | |
| seen, unique = set(), [] | |
| for h in exact_hits: | |
| key = (h["word"], h["file"], h["sentence"][:80]) | |
| if key not in seen: seen.add(key); unique.append(h) | |
| md += "## ✅ Word Found\n\n" | |
| for h in unique: | |
| flag = "🇸🇦" if h["lang"]=="ar" else "🇺🇸" | |
| md += f"- 🔑 **`{h['word']}`** → 📄 `{h['file']}` p.{h['page']} {flag}\n\n > {h['sentence']}\n\n" | |
| detail = run_sentiment_detailed(text) | |
| sent, conf = run_sentiment(text) | |
| md += f"---\n{detail}\n\n---\n## 📍 Exact Location\n\n" | |
| seen2 = set() | |
| for h in unique: | |
| k2 = (h["file"], h["chunk_id"]) | |
| if k2 in seen2: continue | |
| seen2.add(k2) | |
| md += f"### 📄 `{h['file']}` — p.{h['page']} {'🇸🇦' if h['lang']=='ar' else '🇺🇸'}\n\n```\n{KB_TEXTS[h['chunk_id']]}\n```\n\n" | |
| else: | |
| sent, conf = "❌ Not found", 0.0 | |
| if lang == "ar": | |
| md += f"## ❌ الكلمة غير موجودة\n\n**`{text}`** لم تُذكر حرفياً.\n\n" | |
| else: | |
| md += f"## ❌ Word Not Found\n\n**`{text}`** not found literally.\n\n" | |
| if sem_results: | |
| md += "---\n## 🔍 Semantic Results\n\n" | |
| for r in sem_results: | |
| bar = "🟩"*round(r["sem"]*10) + "⬜"*(10-round(r["sem"]*10)) | |
| snippet = r["text"][:300].strip() | |
| for t in qterms: | |
| try: snippet = re.sub(rf"(?i)({re.escape(t)})", r"**\1**", snippet) | |
| except: pass | |
| md += ( | |
| f"### Result {r['rank']} — {bar} `{r['sem']*100:.1f}%` " | |
| f"{'🇸🇦' if r['lang']=='ar' else '🇺🇸'}\n\n" | |
| f"📄 `{r['file']}` p.{r['page']}\n\n> {snippet}...\n\n" | |
| ) | |
| else: | |
| md += "---\n_No similar content found._\n" | |
| return sent, round(conf, 4), md | |
| # ============================================================ | |
| # ECONOMETRICS — World Bank + ARIMA/SARIMAX | |
| # ============================================================ | |
| def get_worldbank_data(country_code, indicator, start_year, end_year): | |
| url = ( | |
| f"https://api.worldbank.org/v2/country/{country_code}/" | |
| f"indicator/{indicator}?date={start_year}:{end_year}&per_page=100&format=json" | |
| ) | |
| try: | |
| resp = requests.get(url, timeout=15) | |
| resp.raise_for_status() | |
| data = resp.json() | |
| if not data or len(data) < 2 or not data[1]: return pd.DataFrame() | |
| rows = [ | |
| {"year": int(e["date"]), "value": float(e["value"])} | |
| for e in data[1] | |
| if e.get("value") is not None and e.get("date") is not None | |
| ] | |
| return pd.DataFrame(rows).dropna().sort_values("year").reset_index(drop=True) | |
| except Exception as e: | |
| print(f"World Bank error: {e}") | |
| return pd.DataFrame() | |
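# Example request built by the function above (Algeria, CPI inflation, 2000-2023):
#   https://api.worldbank.org/v2/country/DZ/indicator/FP.CPI.TOTL.ZG?date=2000:2023&per_page=100&format=json
# The JSON payload is a two-element list: [metadata, observations]; only
# observations with a non-null "value" are kept.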
| def build_doc_sentiment_index(): | |
| if not KB_TEXTS or not KB_META: return None, None | |
| files_texts = {} | |
| for text, meta in zip(KB_TEXTS, KB_META): | |
| files_texts.setdefault(meta["name"], []).append(text[:400]) | |
| yearly_sentiment, file_results = {}, [] | |
| for fname, texts in files_texts.items(): | |
| sample = get_economic_chunks(texts, max_chunks=40) | |
| scores = [sentiment_score_numeric(t) for t in sample] | |
| avg = round(float(np.mean(scores)), 4) | |
| year = next( | |
| (m["year"] for m in KB_META if m["name"]==fname and m.get("year")), None | |
| ) | |
| file_results.append({ | |
| "file": fname, "year": year if year else "N/A", | |
| "sentiment": avg, "n_chunks": len(sample), | |
| "label": "🟢 Optimistic" if avg > 0.05 else "🔴 Pessimistic" if avg < -0.05 else "🟡 Neutral", | |
| }) | |
| if year: | |
| yearly_sentiment.setdefault(year, []).append(avg) | |
| yearly_avg = { | |
| yr: round(float(np.mean(vals)), 4) | |
| for yr, vals in yearly_sentiment.items() | |
| } | |
    # mixed integer years and the "N/A" placeholder cannot be compared directly, so sort numerically
    df_files = pd.DataFrame(file_results)
    df_files = df_files.sort_values("year", key=lambda c: pd.to_numeric(c, errors="coerce"))
| df_yearly = ( | |
| pd.DataFrame([{"year": y, "sentiment": s} for y, s in sorted(yearly_avg.items())]) | |
| if yearly_avg else None | |
| ) | |
| return df_files, df_yearly | |
| def run_adf_check(series: np.ndarray, name: str): | |
| from statsmodels.tsa.stattools import adfuller | |
| def adf_p(s): | |
| try: return adfuller(s, autolag='AIC')[1] | |
| except: return 1.0 | |
| s = series.copy() | |
| p0 = adf_p(s) | |
| if p0 <= 0.05: | |
| return s, f"✅ Stationary at level (p={p0:.4f})", False | |
| s1 = np.diff(s) | |
| p1 = adf_p(s1) | |
| if p1 <= 0.05: | |
| return s1, f"⚠️ Non-stationary (p={p0:.4f}) → 1st diff → ✅ stationary (p={p1:.4f})", True | |
| s2 = np.diff(s1) | |
| p2 = adf_p(s2) | |
| return ( | |
| s2, | |
| f"⚠️ Non-stationary (p={p0:.4f}) → 1st diff (p={p1:.4f}) → 2nd diff → " | |
| f"{'✅ stationary' if p2<=0.05 else '⚠️ non-stationary'} (p={p2:.4f})", | |
| True, | |
| ) | |
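# ADF pre-check: the series is differenced at most twice, stopping as soon as
# the ADF p-value drops to 0.05 or below; the returned flag marks whether any
# differencing was applied before the Granger test.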
| def run_granger_test(series_y, series_exog, maxlag=4): | |
| try: | |
| from statsmodels.tsa.stattools import grangercausalitytests | |
| if len(series_y) < 10: | |
| return "⚠️ **Granger Test skipped** — need ≥ 10 points.", False | |
| sy, status_y = run_adf_check(series_y.copy(), "Target")[:2] | |
| sexog, status_exog = run_adf_check(series_exog.copy(), "Sentiment")[:2] | |
| min_len = min(len(sy), len(sexog)) | |
| sy, sexog = sy[-min_len:], sexog[-min_len:] | |
| maxlag = min(maxlag, max(1, (len(sy) - 1) // 3)) | |
| if len(sy) < 5: | |
| return "⚠️ **Granger Test skipped** — too few obs after differencing.", False | |
| gc_result = grangercausalitytests( | |
| np.column_stack([sy, sexog]), maxlag=maxlag, verbose=False | |
| ) | |
| rows, any_pass, best_p = [], False, 1.0 | |
| for lag, res in gc_result.items(): | |
| p_val = res[0]["ssr_ftest"][1] | |
| f_val = res[0]["ssr_ftest"][0] | |
| if p_val < 0.05: sig = "✅ Yes"; any_pass = True | |
| elif p_val < 0.10: sig = "🔶 Marginal" | |
| else: sig = "❌ No" | |
| best_p = min(best_p, p_val) | |
| rows.append(f"| {lag} | {f_val:.4f} | {p_val:.4f} | {sig} |") | |
| table = ( | |
| "### 🔬 Granger Causality Test\n" | |
| "*H₀: Sentiment does NOT Granger-cause Target*\n\n" | |
| f"#### 📋 ADF Stationarity Pre-check\n\n" | |
| f"| Series | ADF Result |\n|---|---|\n" | |
| f"| 🎯 Target | {status_y} |\n" | |
| f"| 😊 Sentiment | {status_exog} |\n\n" | |
| "#### 📊 Granger Results\n\n" | |
| "| Lag | F-stat | p-value | Significant? |\n|-----|--------|---------|-------------|\n" | |
| + "\n".join(rows) | |
| ) | |
| if any_pass: | |
| verdict = f"\n\n✅ **PASS** — Sentiment significantly Granger-causes the target (p < 0.05)." | |
| elif best_p < 0.10: | |
| verdict = f"\n\n🔶 **MARGINAL** — best p = {best_p:.4f} (< 0.10)." | |
| else: | |
| verdict = "\n\n❌ **FAIL** — No significant Granger causality (p ≥ 0.05)." | |
| return table + verdict, any_pass | |
| except Exception as e: | |
| return f"⚠️ Granger test error: `{e}`\n", False | |
| def run_dm_test(actual, pred_arima, pred_sarimax): | |
| try: | |
| n = len(actual) | |
| if n < 3: | |
| return "⚠️ **DM Test skipped** — n < 3.", False | |
| d = (actual - pred_arima)**2 - (actual - pred_sarimax)**2 | |
| d_mean = np.mean(d) | |
| d_std = np.std(d, ddof=1) | |
| if d_std < 1e-10: | |
| return "⚠️ **DM Test** — models identical.", False | |
| dm_stat = d_mean / (d_std / np.sqrt(n)) | |
| p_val = 2 * (1 - stats.t.cdf(abs(dm_stat), df=n - 1)) | |
| sig = "✅ Yes" if p_val < 0.05 else ("🔶 Marginal" if p_val < 0.10 else "❌ No") | |
| better = "SARIMAX+Ensemble" if dm_stat > 0 else "ARIMA" | |
| table = ( | |
| "### 🎯 Diebold-Mariano Test\n" | |
| "*H₀: Equal predictive accuracy | H₁: SARIMAX better than ARIMA*\n\n" | |
| "| DM Statistic | p-value | n (test) | Significant? | Better Model |\n" | |
| "|-------------|---------|----------|-------------|-------------|\n" | |
| f"| `{dm_stat:.4f}` | `{p_val:.4f}` | `{n}` | {sig} | **{better}** |\n" | |
| ) | |
| passed = p_val < 0.05 and dm_stat > 0 | |
| if passed: | |
| verdict = "\n✅ **PASS** — SARIMAX+Ensemble is **significantly better** (p < 0.05)." | |
| elif (p_val < 0.10) and dm_stat > 0: | |
| verdict = f"\n🔶 **MARGINAL** — p = {p_val:.4f} (< 0.10)." | |
| else: | |
| verdict = ( | |
| f"\n❌ **FAIL** — Not statistically significant (p = {p_val:.4f}).\n\n" | |
| f"> 💡 With n = {n} test points, power is limited. " | |
| f"Expand Start Year to 1990 for more test data." | |
| ) | |
| return table + verdict, passed | |
| except Exception as e: | |
| return f"⚠️ DM error: `{e}`\n", False | |
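# Diebold-Mariano statistic as implemented above (squared-error loss, no
# small-sample HLN correction):
#   d_t = (y_t - yhat_ARIMA,t)^2 - (y_t - yhat_SARIMAX,t)^2
#   DM  = mean(d) / (sd(d) / sqrt(n)), compared to a t distribution with n - 1 df
# DM > 0 means the SARIMAX+Ensemble errors are smaller on the test window.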
| # ============================================================ | |
| # MAIN FORECAST FUNCTION — n = 3 | |
| # ============================================================ | |
| def run_economic_forecast(country_code, target_var, start_year, end_year): | |
| try: | |
| from statsmodels.tsa.arima.model import ARIMA | |
| from statsmodels.tsa.statespace.sarimax import SARIMAX | |
| from sklearn.metrics import mean_squared_error, mean_absolute_error | |
| except ImportError: | |
| return "❌ pip install statsmodels scikit-learn", None | |
    indicator_map = {
        "Inflation (CPI %)": "FP.CPI.TOTL.ZG",
        "GDP Growth (%)"   : "NY.GDP.MKTP.KD.ZG",
        "Unemployment (%)" : "SL.UEM.TOTL.ZS",
        "Exchange Rate"    : "PA.NUS.FCRF",
    }
| econ_df = get_worldbank_data( | |
| country_code, | |
| indicator_map.get(target_var, "FP.CPI.TOTL.ZG"), | |
| int(start_year), int(end_year), | |
| ) | |
| if econ_df.empty: | |
| return f"❌ No data for **{country_code}** / **{target_var}**", None | |
| if len(econ_df) < 5: | |
| return f"⚠️ Only **{len(econ_df)}** data points. Widen year range.", None | |
| df_files, df_yearly = build_doc_sentiment_index() | |
| if df_yearly is not None and len(df_yearly) >= 2: | |
| merged = econ_df.merge(df_yearly, on="year", how="left") | |
| merged["sentiment"] = merged["sentiment"].fillna( | |
| float(df_yearly["sentiment"].mean()) | |
| ) | |
| has_yearly = True | |
| mode_msg = "✅ **Yearly Ensemble Sentiment**" | |
| else: | |
| global_sent = ( | |
| float(pd.to_numeric(df_files["sentiment"], errors="coerce").mean()) | |
| if df_files is not None and len(df_files) > 0 else 0.0 | |
| ) | |
| merged = econ_df.copy() | |
| merged["sentiment"] = global_sent | |
| has_yearly = False | |
| mode_msg = "⚠️ **Global Sentiment**" | |
| if merged["sentiment"].std() > 1e-6: | |
| scaler = MinMaxScaler(feature_range=(-0.3, 0.3)) | |
| merged["sentiment"] = scaler.fit_transform( | |
| merged["sentiment"].values.reshape(-1, 1) | |
| ).flatten().round(4) | |
| series = merged["value"].values.astype(float) | |
| exog = merged["sentiment"].values.reshape(-1, 1) | |
| years = merged["year"].values | |
| n = len(series) | |
| # ============================================================ | |
| # ✅ n = 3 — Test on last 3 years | |
| # ============================================================ | |
    split = n - 3
    if split < 5:
        # safety fallback for very short series; always keep at least one test point
        split = min(max(int(n * 0.75), 5), n - 1)
| train_y, test_y = series[:split], series[split:] | |
| train_exog, test_exog = exog[:split], exog[split:] | |
| test_years = years[split:] | |
| # ARIMA baseline | |
| try: | |
| m1 = ARIMA(train_y, order=(1,1,1)).fit() | |
| pred_arima = m1.forecast(len(test_y)) | |
| rmse_a = float(np.sqrt(mean_squared_error(test_y, pred_arima))) | |
| mae_a = float(mean_absolute_error(test_y, pred_arima)) | |
| mape_a = float(np.mean(np.abs((test_y-pred_arima)/np.maximum(np.abs(test_y),1e-8)))*100) | |
| except Exception as e: | |
| return f"❌ ARIMA error: {e}", None | |
| # SARIMAX + Ensemble Sentiment | |
| try: | |
| m2 = SARIMAX(train_y, exog=train_exog, order=(1,1,1)).fit(disp=False) | |
| pred_sarimax = m2.forecast(len(test_y), exog=test_exog) | |
| rmse_s = float(np.sqrt(mean_squared_error(test_y, pred_sarimax))) | |
| mae_s = float(mean_absolute_error(test_y, pred_sarimax)) | |
| mape_s = float(np.mean(np.abs((test_y-pred_sarimax)/np.maximum(np.abs(test_y),1e-8)))*100) | |
| except Exception as e: | |
| return f"❌ SARIMAX error: {e}", None | |
| impr_rmse = (rmse_a - rmse_s) / rmse_a * 100 | |
| impr_mae = (mae_a - mae_s) / mae_a * 100 | |
| impr_mape = (mape_a - mape_s) / mape_a * 100 | |
| # Granger — use full series | |
| if has_yearly and df_yearly is not None and len(df_yearly) >= 5: | |
| real_merged = econ_df.merge(df_yearly, on="year", how="inner") | |
| gc_y = real_merged["value"].values.astype(float) | |
| gc_exog = real_merged["sentiment"].values.astype(float) | |
| else: | |
| gc_y = series | |
| gc_exog = merged["sentiment"].values | |
| granger_md, granger_pass = run_granger_test(gc_y, gc_exog, maxlag=4) | |
| dm_md, dm_pass = run_dm_test(test_y, np.array(pred_arima), np.array(pred_sarimax)) | |
| # ============================================================ | |
| # PLOTS | |
| # ============================================================ | |
| fig, axes = plt.subplots(4, 1, figsize=(11, 18)) | |
| # Plot 1 — Forecast | |
| axes[0].plot(years, series, "o-", color="#2196F3", label="Actual", lw=2, ms=5) | |
| axes[0].plot(test_years, pred_arima, "s--", color="#FF5722", label="ARIMA(1,1,1)", lw=2) | |
| axes[0].plot(test_years, pred_sarimax, "^-.", color="#4CAF50", label="SARIMAX+Ensemble", lw=2) | |
| axes[0].axvline(x=years[split-1], color="gray", linestyle=":", alpha=0.7, label="Train│Test") | |
| axes[0].set_title( | |
| f"📈 {target_var} — {country_code} (Yearly Ensemble Sentiment) | n_test={len(test_y)}", | |
| fontsize=11, fontweight="bold", | |
| ) | |
| axes[0].set_xlabel("Year"); axes[0].set_ylabel(target_var) | |
| axes[0].legend(fontsize=9); axes[0].grid(True, alpha=0.3) | |
| # Plot 2 — Sentiment Index | |
| s_clrs = [ | |
| "#4CAF50" if s > 0.05 else "#FF5722" if s < -0.05 else "#FFC107" | |
| for s in merged["sentiment"] | |
| ] | |
| axes[1].bar(years, merged["sentiment"], color=s_clrs, edgecolor="white", width=0.6) | |
| axes[1].axhline(y=0, color="black", lw=0.8) | |
| legend_patches = [ | |
| Patch(color="#4CAF50", label="Optimistic (>0.05)"), | |
| Patch(color="#FFC107", label="Neutral"), | |
| Patch(color="#FF5722", label="Pessimistic (<-0.05)"), | |
| ] | |
| axes[1].legend(handles=legend_patches, fontsize=8, loc="upper right") | |
| axes[1].set_title( | |
| "📊 Ensemble Sentiment Index (FinBERT 40% + XLM 30% + Lexicon 30%)\n" | |
| "per-year — normalized [-0.3, +0.3]", | |
| fontsize=10, fontweight="bold", | |
| ) | |
| axes[1].set_xlabel("Year"); axes[1].set_ylabel("Sentiment Score") | |
| axes[1].grid(True, alpha=0.3, axis="y") | |
| # Plot 3 — RMSE Bar | |
    bar_colors = ["#4CAF50" if rmse_a <= rmse_s else "#FF5722",
                  "#4CAF50" if rmse_s <= rmse_a else "#FF5722"]
| bars = axes[2].bar( | |
| ["ARIMA(1,1,1)", "SARIMAX\n+Ensemble"], | |
| [rmse_a, rmse_s], color=bar_colors, width=0.4, edgecolor="white", | |
| ) | |
| for bar, val in zip(bars, [rmse_a, rmse_s]): | |
| axes[2].text( | |
| bar.get_x()+bar.get_width()/2, bar.get_height()+0.01, | |
| f"{val:.4f}", ha="center", va="bottom", fontweight="bold", fontsize=11, | |
| ) | |
| axes[2].set_title("📉 RMSE Comparison (lower = better)", fontsize=11) | |
| axes[2].set_ylabel("RMSE"); axes[2].grid(True, alpha=0.3, axis="y") | |
| # Plot 4 — Statistical Tests Summary Table | |
| axes[3].axis("off") | |
| test_data = [ | |
| ["Test", "Result", "Interpretation"], | |
| [ | |
| "Granger (ADF + Granger)", | |
| "✅ PASS" if granger_pass else "❌ FAIL", | |
| "Sentiment Granger-causes Target" if granger_pass else "No causal link detected", | |
| ], | |
| [ | |
| "Diebold-Mariano\n(SARIMAX vs ARIMA)", | |
| "✅ PASS" if dm_pass else "❌ FAIL", | |
| "SARIMAX significantly better" if dm_pass else f"n_test={len(test_y)} — limited power", | |
| ], | |
| ] | |
| tbl4 = axes[3].table( | |
| cellText=test_data[1:], colLabels=test_data[0], | |
| cellLoc="center", loc="center", colWidths=[0.35, 0.2, 0.45], | |
| ) | |
| tbl4.auto_set_font_size(False); tbl4.set_fontsize(11); tbl4.scale(1, 2.5) | |
| for (row, col), cell in tbl4.get_celld().items(): | |
| if row == 0: | |
| cell.set_facecolor("#1565C0") | |
| cell.set_text_props(color="white", fontweight="bold") | |
| elif row == 1: | |
| cell.set_facecolor("#E8F5E9" if granger_pass else "#FFEBEE") | |
| elif row == 2: | |
| cell.set_facecolor("#E8F5E9" if dm_pass else "#FFEBEE") | |
| axes[3].set_title( | |
| "🔬 Statistical Tests: ADF + Granger + DM", | |
| fontsize=12, fontweight="bold", pad=20, | |
| ) | |
| plt.tight_layout(pad=3.0) | |
| img_path = "/tmp/forecast_plot.png" | |
| plt.savefig(img_path, dpi=130, bbox_inches="tight") | |
| plt.close(fig) | |
| # ============================================================ | |
| # RESULT TEXT | |
| # ============================================================ | |
| sent_table = "" | |
| if df_files is not None and len(df_files) > 0: | |
| sent_table = ( | |
| "\n---\n### 📄 Ensemble Sentiment per File\n" | |
| "| 📄 File | 📅 Year | 😊 Score | 📦 Chunks | Label |\n|---|---|---|---|---|\n" | |
| ) | |
| for _, row in df_files.iterrows(): | |
| sent_table += ( | |
| f"| `{row['file']}` | {row['year']} | " | |
| f"`{row['sentiment']:+.4f}` | {row['n_chunks']} | {row['label']} |\n" | |
| ) | |
| result_md = ( | |
| f"## 📊 Forecast — {country_code} / {target_var}\n\n" | |
| f"| | |\n|---|---|\n" | |
| f"| 🎯 Target Variable | **{target_var}** |\n" | |
| f"| 📈 Sentiment Mode | {mode_msg} |\n" | |
| f"| 📈 Train samples | **{split}** |\n" | |
| f"| 🧪 Test samples (n)| **{len(test_y)}** |\n\n" | |
| f"---\n### 🏆 Model Comparison\n" | |
| f"| Model | RMSE | MAE | MAPE |\n|---|---|---|---|\n" | |
| f"| ARIMA(1,1,1) | `{rmse_a:.4f}` | `{mae_a:.4f}` | `{mape_a:.1f}%` |\n" | |
| f"| SARIMAX+Ensemble | `{rmse_s:.4f}` | `{mae_s:.4f}` | `{mape_s:.1f}%` |\n" | |
| f"| **Improvement** | **{impr_rmse:+.1f}%** | **{impr_mae:+.1f}%** | **{impr_mape:+.1f}%** |\n\n" | |
| f"{'✅ **Improved** by adding Ensemble Sentiment Index.' if impr_rmse > 0 else '⚠️ No RMSE improvement for this variable.'}\n\n" | |
| f"---\n{granger_md}\n\n---\n{dm_md}\n{sent_table}" | |
| ) | |
| return result_md, img_path | |
| # ============================================================ | |
| # UTILITIES | |
| # ============================================================ | |
def generate_report(text, sent, conf, md):
    path = "/tmp/report.md"
    with open(path, "w", encoding="utf-8") as f:
        f.write(f"# Report\n\n**Input:** {text}\n**Sentiment:** {sent} (score: {conf})\n\n{md}")
    return path
| def export_chat(history): | |
| path = "/tmp/chat.txt" | |
| with open(path, "w", encoding="utf-8") as f: | |
| for turn in history: | |
| f.write(f"{turn['role']}:\n{turn['content']}\n\n") | |
| return path | |
| def get_stats(): | |
| return ( | |
| f"### 📊 Session Stats\n\n" | |
| f"| | |\n|---|---|\n" | |
| f"| ❓ Questions asked | **{CHAT_STATS['questions']}** |\n" | |
| f"| ✅ RAG answers | **{CHAT_STATS['found']}** |\n" | |
| f"| 🤖 General answers | **{CHAT_STATS['not_found']}** |\n" | |
| f"| 📦 Chunks indexed | **{len(KB_TEXTS):,}** |\n" | |
| ) | |
| def get_top_keywords(): | |
| if not KB_TEXTS: return "_No files uploaded yet._" | |
| stopwords = {"this","that","with","from","have","been","were","they","their", | |
| "there","what","when","which","will","also","than","into","more"} | |
| top = Counter( | |
| w for w in re.findall(r"\b\w{4,}\b", " ".join(KB_TEXTS).lower()) | |
| if w not in stopwords | |
| ).most_common(20) | |
| return "### 🔑 Top 20 Keywords\n\n" + "\n".join(f"- **{w}**: {c}" for w,c in top) | |
| def update_threshold(val): | |
| global MIN_SIMILARITY | |
| MIN_SIMILARITY = val | |
| return f"✅ Threshold set to: {val:.0%}" | |
| def chat_text(message, history): | |
| if not message.strip(): return "", history | |
| answer, _ = smart_answer(message, history) | |
| return "", history + [ | |
| {"role": "user", "content": message}, | |
| {"role": "assistant", "content": answer}, | |
| ] | |
| def tts_save(text, lang="en"): | |
| path = "/tmp/ans.mp3" | |
| gTTS( | |
| text=re.sub(r"[*`#>\[\]|_]", "", text)[:600], | |
| lang="ar" if lang == "ar" else "en", | |
| ).save(path) | |
| return path | |
def chat_voice(audio, history):
    if audio is None: raise gr.Error("No audio received.")
    sr, y = audio
    y = np.array(y) if isinstance(y, list) else y
    if y.ndim > 1: y = y.mean(axis=1)
    y = y.astype(np.float32)
    # Gradio's numpy audio is int16 PCM; Whisper expects float waveforms in [-1, 1]
    if np.abs(y).max() > 1.0:
        y /= 32768.0
    transcript = asr({"array": y, "sampling_rate": sr})["text"]
| lang = detect_lang(transcript) | |
| answer, _ = smart_answer(transcript, history) | |
| new_history = history + [ | |
| {"role": "user", "content": f"🎙️ {transcript}"}, | |
| {"role": "assistant", "content": answer}, | |
| ] | |
| return new_history, tts_save(answer, lang), transcript | |
| # ============================================================ | |
| # GRADIO UI | |
| # ============================================================ | |
| with gr.Blocks(title="RAG + Sentiment + Forecast", theme=gr.themes.Soft()) as app: | |
| gr.Markdown( | |
| "# 🤖 Hybrid Multilingual RAG + Ensemble Sentiment + Economic Forecast\n" | |
| "**ENSSEA — Master's Thesis | Si Tayeb Houari | 2025–2026**" | |
| ) | |
| # ── Tab 1: Upload ───────────────────────────────────────── | |
| with gr.Tab("📁 1 · Upload"): | |
| files = gr.File( | |
| label="📂 Upload Files (PDF / TXT / CSV / DOCX)", | |
| file_types=[".pdf",".txt",".csv",".docx"], | |
| file_count="multiple", type="filepath", | |
| ) | |
| build_btn = gr.Button("🔨 Build Index", variant="primary") | |
| status = gr.Markdown("_No index built yet._") | |
| with gr.Row(): | |
| save_btn = gr.Button("💾 Save Index") | |
| load_btn = gr.Button("🔄 Load Saved Index") | |
| persist_status = gr.Markdown() | |
| sim_slider = gr.Slider(0.0, 1.0, value=0.10, step=0.05, label="🎯 Similarity Threshold") | |
| threshold_status = gr.Markdown() | |
| build_btn.click(build_index, inputs=files, outputs=status) | |
| save_btn.click(save_index, outputs=persist_status) | |
| load_btn.click(load_saved_index, outputs=persist_status) | |
| sim_slider.change(update_threshold, inputs=sim_slider, outputs=threshold_status) | |
| # ── Tab 2: Sentiment & Word Search ──────────────────────── | |
| with gr.Tab("🎭 2 · Sentiment & Search"): | |
| inp = gr.Textbox(lines=3, label="📝 Enter text or keyword") | |
| run_btn = gr.Button("🔍 Analyze & Search", variant="primary") | |
| with gr.Row(): | |
| out_sent = gr.Textbox(label="🎭 Sentiment") | |
| out_conf = gr.Number(label="📊 Score") | |
| out_full = gr.Markdown() | |
| rep_btn = gr.Button("📄 Download Report") | |
| rep_file = gr.File(label="📥 Report") | |
| run_btn.click(predict_with_rag, inputs=inp, outputs=[out_sent, out_conf, out_full]) | |
| rep_btn.click(generate_report, inputs=[inp, out_sent, out_conf, out_full], outputs=rep_file) | |
| # ── Tab 3: Smart Chatbot ────────────────────────────────── | |
| with gr.Tab("💬 3 · Smart Chatbot"): | |
| chatbot = gr.Chatbot(height=430, type="messages", show_label=False) | |
| msg = gr.Textbox(placeholder="Ask anything about your documents…", label="💬 Message") | |
| with gr.Row(): | |
| send_btn = gr.Button("📨 Send", variant="primary") | |
| clear_btn = gr.Button("🗑️ Clear") | |
| exp_btn = gr.Button("📥 Export") | |
| exp_file = gr.File(label="💾 Chat Export") | |
| msg.submit(chat_text, inputs=[msg, chatbot], outputs=[msg, chatbot]) | |
| send_btn.click(chat_text, inputs=[msg, chatbot], outputs=[msg, chatbot]) | |
| clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg]) | |
| exp_btn.click(export_chat, inputs=chatbot, outputs=exp_file) | |
| # ── Tab 4: Voice ────────────────────────────────────────── | |
| with gr.Tab("🎙️ 4 · Voice"): | |
| gr.Markdown("### 🎙️ Speak your question — get a spoken answer") | |
| voice_input = gr.Audio(sources=["microphone"], type="numpy", label="🎤 Record") | |
| voice_btn = gr.Button("🎙️ Ask by Voice", variant="primary") | |
| voice_chat = gr.Chatbot(height=300, type="messages") | |
| audio_output = gr.Audio(label="🔊 Answer", autoplay=True) | |
| transcript_out= gr.Textbox(label="📝 Transcript") | |
| voice_btn.click(chat_voice, inputs=[voice_input, voice_chat], | |
| outputs=[voice_chat, audio_output, transcript_out]) | |
| # ── Tab 5: Analytics ───────────────────────────────────── | |
| with gr.Tab("📊 5 · Analytics"): | |
| stats_btn = gr.Button("📊 Refresh Stats") | |
| stats_out = gr.Markdown() | |
| kw_btn = gr.Button("🔑 Top Keywords") | |
| kw_out = gr.Markdown() | |
| stats_btn.click(get_stats, outputs=stats_out) | |
| kw_btn.click(get_top_keywords, outputs=kw_out) | |
| # ── Tab 6: About ────────────────────────────────────────── | |
| with gr.Tab("ℹ️ 6 · About"): | |
| gr.Markdown( | |
| "## 🤖 Hybrid Multilingual RAG Framework\n\n" | |
| "| Component | Details |\n|---|---|\n" | |
| "| 🏫 School | ENSSEA — École Nationale Supérieure de Statistique et d'Économie Appliquée |\n" | |
| "| 👤 Author | Si Tayeb Houari |\n" | |
| "| 📅 Year | 2025–2026 |\n" | |
| "| 🎓 Degree | Master's — Statistics & Foresight Economics |\n\n" | |
| "### 🔧 Models Used\n" | |
| "- 🏦 **FinBERT** (ProsusAI) — Financial sentiment (40%)\n" | |
| "- 🌍 **XLM-RoBERTa** (CardiffNLP) — Multilingual sentiment (30%)\n" | |
| "- 📖 **Economic Lexicon** — Domain-specific keywords (30%)\n" | |
| "- 🔍 **MiniLM-L12** — Multilingual embeddings (FAISS)\n" | |
| "- 📊 **ms-marco-MiniLM** — Cross-encoder reranking\n" | |
| "- 🗣️ **Whisper-small** — ASR\n" | |
| "- 🤖 **Llama-3.3-70B** via Groq — Response generation\n\n" | |
| "### 📊 Forecasting\n" | |
| "- Baseline: **ARIMA(1,1,1)**\n" | |
| "- Enhanced: **SARIMAX + Ensemble Sentiment** (n_test = 3)\n" | |
| "- Tests: **ADF**, **Granger Causality**, **Diebold-Mariano**\n" | |
| "- Data: **World Bank API**\n" | |
| ) | |
| # ── Tab 7: Economic Forecast ────────────────────────────── | |
| with gr.Tab("📈 7 · Forecast"): | |
| gr.Markdown( | |
| "## 📈 Economic Forecast — ARIMA vs SARIMAX + Ensemble Sentiment\n" | |
| "> **n_test = 3** — Evaluates on the last 3 years (captures recent economic turbulence)" | |
| ) | |
| with gr.Row(): | |
| country_input = gr.Textbox( | |
| value="DZ", label="🌍 Country Code (ISO)", | |
| placeholder="e.g. DZ, MA, TN, EG, US", | |
| ) | |
| target_input = gr.Dropdown( | |
                choices=[
                    "Inflation (CPI %)",
                    "GDP Growth (%)",
                    "Unemployment (%)",
                    "Exchange Rate",
                ],
| value="Inflation (CPI %)", | |
| label="🎯 Target Variable", | |
| ) | |
| with gr.Row(): | |
| start_year = gr.Slider( | |
| minimum=1990, maximum=2020, value=2000, step=1, label="📅 Start Year" | |
| ) | |
| end_year = gr.Slider( | |
| minimum=2010, maximum=2024, value=2023, step=1, label="📅 End Year" | |
| ) | |
| forecast_btn = gr.Button("📈 Run Forecast", variant="primary", size="lg") | |
| forecast_result = gr.Markdown() | |
| forecast_plot = gr.Image(label="📊 Forecast Chart", type="filepath") | |
| forecast_btn.click( | |
| run_economic_forecast, | |
| inputs=[country_input, target_input, start_year, end_year], | |
| outputs=[forecast_result, forecast_plot], | |
| ) | |
| app.launch(server_name="0.0.0.0", server_port=7860, show_api=False) | |