tecuhtli committed on
Commit afe0be6 · verified · 1 Parent(s): a6f140a

Update app.py

Files changed (1)
  1. app.py +210 -12
app.py CHANGED
@@ -1,5 +1,5 @@
 #***************************************************************************
-#Importing Libraries
+# Importing Libraries
 #***************************************************************************
 import os, sys, warnings, torch, json, csv, warnings, joblib, uuid, re, unicodedata, faiss
 import numpy as np
@@ -11,11 +11,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSeque
 from unidecode import unidecode
 from datetime import datetime
 from huggingface_hub import hf_hub_download, login
-#***************************************************************************
-#Defining default paths for the model to work
-#***************************************************************************
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from sentence_transformers import SentenceTransformer  # RAG embeddings

 #***************************************************************************
 #Setting up variables
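
Note: this hunk deletes the module-level `device` definition, while later code still passes a `device` argument around and the new RAG branch computes its own `dev_str`. The definition presumably moved elsewhere in the file; a one-line sketch of what must still exist somewhere (its new location is an assumption, not shown in this diff):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumed relocated, e.g. into main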
@@ -180,9 +176,20 @@ def sidebar_params():
     # In session_state:
     if "PROMPT_CASES" not in st.session_state:
         st.session_state.PROMPT_CASES = load_prompt_cases()
-

-    st.subheader("🧾 Vista previa del Prompt")
+
+    st.markdown("---")
+    st.title("👀 RAG (Modelo Técnico)")
+    ss.setdefault("use_rag", True)
+    ss.setdefault("rag_k", 5)
+    ss.use_rag = st.checkbox("Usar RAG (técnico)", value=ss.use_rag,
+                             help="Recupera evidencias de ./Vec_DataBase/mori.* y las cita en el prompt.")
+    ss.rag_k = st.slider("k evidencias", 3, 9, int(ss.rag_k),
+                         help="https://huggingface.co/docs/transformers/en/model_doc/rag")
+
+
+    st.markdown("---")
+    st.title("🧾 Vista previa del Prompt")

     if "last_prompt" in st.session_state and st.session_state["last_prompt"]:
         with st.expander("Mostrar prompt generado"):
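
Note: `ss` is not defined inside this hunk; it presumably aliases `st.session_state` earlier in `sidebar_params()`, along the lines of this one-line sketch (the alias itself is an assumption):

ss = st.session_state  # shorthand used by the sidebar widgets

Also worth flagging: the slider's help link points at the Transformers RagModel docs, which describe a DPR-based retriever rather than the E5 + FAISS pipeline this commit adds.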
@@ -818,9 +825,201 @@ def contextual_asnwer(question, label_classes, context_model, cont_tok,
     set_seeds(gen_params["seed"])

     if context == "social":
+        # Note: per the analysis results, RAG adds nothing for the social model (very redundant dataset).
+        # It can be enabled in the future if dataset diversity grows.
         return social_asnwer(question, soc_model, soc_tok, device, gen_params=gen_params, block_web=block_web), context
     else:
-        return technical_asnwer(question, context, tec_model, tec_tok, device, gen_params=gen_params), context
+        # Technical: use RAG if the user enabled it
+        use_rag = st.session_state.get("use_rag", False)
+        if use_rag:
+            # One-time load of E5 + FAISS (cache_resource)
+            dev_str = "cuda" if torch.cuda.is_available() else "cpu"
+            e5, index, metas = load_rag_assets(dev_str)
+            if e5 is None:
+                # Fallback if the RAG base is not found
+                return technical_asnwer(question, context, tec_model, tec_tok, device, gen_params=gen_params), context
+            resp = technical_answer_rag(
+                question, tec_model, tec_tok, device, gen_params,
+                e5=e5, index=index, metas=metas,
+                k=st.session_state.get("rag_k", 5), sim_threshold=0.40
+            )
+            return resp, context
+        else:
+            return technical_asnwer(question, context, tec_model, tec_tok, device, gen_params=gen_params), context
+
+
+
+# ============================
+# RAG assets (loaded once)
+# ============================
+@st.cache_resource
+def load_rag_assets(device_str: str = "cpu"):
+    """Loads E5 + FAISS + metadata from ./Vec_DataBase using the mori.* names."""
+    vdb_dir = Path("Vec_DataBase")
+    faiss_path = vdb_dir / "mori.faiss"
+    metas_path = vdb_dir / "mori_metas.json"
+
+    if not faiss_path.exists() or not metas_path.exists():
+        st.warning("⚠️ No se encontró la base RAG en ./Vec_DataBase (mori.faiss / mori_metas.json).")
+        return None, None, None
+
+    e5 = SentenceTransformer("intfloat/multilingual-e5-base", device=device_str)
+    index = faiss.read_index(str(faiss_path))
+    with open(metas_path, "r", encoding="utf-8") as f:
+        metas = json.load(f)
+    return e5, index, metas
+
+
+def rag_retrieve(e5, index, metas, user_text: str, k: int = 5):
+    """Top-k by cosine similarity (inner product over normalized embeddings)."""
+    if e5 is None or index is None or metas is None or index.ntotal == 0:
+        return []
+    qv = e5.encode([f"query: {user_text}"], normalize_embeddings=True,
+                   convert_to_numpy=True).astype("float32")
+    k = max(1, min(int(k), index.ntotal))
+    scores, idxs = index.search(qv, k)
+    out = []
+    for rank, (s, i) in enumerate(zip(scores[0], idxs[0]), 1):
+        if i == -1:
+            continue
+        m = metas[i]
+        out.append({
+            "rank": rank, "score": float(s),
+            "id": m.get("id", ""),
+            "canonical_term": m.get("canonical_term", ""),
+            "context": m.get("context", ""),
+            "input": m.get("input", ""),
+            "output": m.get("output", ""),
+        })
+    return out
+
+
+def _format_evidence(passages):
+    lines = []
+    for p in passages:
+        lines.append(
+            f"[{p['rank']}] term='{p['canonical_term']}' ctx='{p['context']}'\n"
+            f"   Q: {p['input']}\n"
+            f"   A: {p['output']}"
+        )
+    return "\n".join(lines)
+
+
+def build_rag_prompt_technical(base_prompt: str, user_text: str, passages):
+    ev_lines = []
+    for p in passages:
+        ev_lines.append(
+            f"[{p['rank']}] term='{p.get('canonical_term','')}' ctx='{p.get('context','')}'\n"
+            f"input: {p.get('input','')}\n"
+            f"output: {p.get('output','')}"
+        )
+
+    ev_block = "\n".join(ev_lines)
+    rag_rules = (
+        "\n\n[ Modo RAG ]\n"
+        "- Usa EXCLUSIVAMENTE la información relevante de las evidencias.\n"
+        "- Si algo no aparece en las evidencias, dilo explícitamente.\n"
+        "- Cita las evidencias con [n] (ej. [1], [3]).\n"
+    )
+    return f"{base_prompt.strip()}\n{rag_rules}\nEVIDENCIAS:\n{ev_block}\n"
+
+
+def get_bad_words_ids(tok):
+    bad = []
+    for sym in ["[", "]"]:
+        ids = tok.encode(sym, add_special_tokens=False)  # e.g. [784]
+        if ids and all(isinstance(t, int) and t >= 0 for t in ids):
+            bad.append(ids)  # [[784]]
+    return bad
+
+
+# --- UPDATED FUNCTION: Prompt Engineering + RAG as separate layers ---
+def technical_answer_rag(
+    question, tec_model, tec_tok, device, gen_params,
+    e5, index, metas, k=5, sim_threshold=0.40
+):
+    """Orchestrates retrieval + (Prompt Engineering base_prompt) + RAG injection + generation."""
+    passages = rag_retrieve(e5, index, metas, question, k=k)
+    if not passages:
+        return "No encontré evidencias relevantes para responder con certeza. ¿Puedes dar más contexto?"
+
+    # 1) Prompt Engineering (STYLE/ROLE/PERSONA) → base_prompt
+    persona_name = (gen_params or {}).get("persona", st.session_state.get("persona", "Mori Normal"))
+    prompt_type = st.session_state.get("prompt_type", "Zero-shot")
+    base_prompt = build_prompt_from_cases(  # <<-- existing Prompt Engineering function
+        domain="technical",
+        prompt_type=prompt_type,
+        persona=persona_name,
+        question=question,
+        context="RAG"  # informational label
+    )
+
+    # 2) RAG (CONTENT/EVIDENCE) → injected on top of the base_prompt
+    prompt = build_rag_prompt_technical(base_prompt, question, passages)
+
+    # 3) UI: store the prompt and flag low similarity if applicable
+    max_sim = passages[0]["score"]
+    if max_sim < sim_threshold:
+        prompt = "⚠️ Baja similitud con la base; podría faltar contexto.\n\n" + prompt
+    st.session_state["last_prompt"] = prompt
+    st.session_state["just_generated"] = True
+
+    # 4) Generation
+    enc = tec_tok(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
+
+    bad_ids = get_bad_words_ids(tec_tok)  # optional; remove it to allow free brackets
+
+    max_new = int(gen_params.get("max_new_tokens"))
+    min_new = int(gen_params.get("min_tokens"))
+    no_repeat = int(gen_params.get("no_repeat_ngram_size"))
+    rep_pen = float(gen_params.get("repetition_penalty"))
+    mode = gen_params.get("mode", "beam")
+
+    # Control IDs (in case the tokenizer does not define them)
+    eos_id = tec_tok.eos_token_id or tec_tok.convert_tokens_to_ids("</s>")
+    pad_id = tec_tok.pad_token_id or eos_id
+
+    if mode == "sampling":
+        temperature = float(gen_params.get("temperature", 0.7))
+        top_p = float(gen_params.get("top_p", 0.9))
+        kwargs = dict(
+            do_sample=True, num_beams=1,
+            temperature=max(0.1, temperature),
+            top_p=min(1.0, max(0.5, top_p)),
+            max_new_tokens=max_new,
+            min_new_tokens=max(0, min_new),
+            no_repeat_ngram_size=no_repeat,
+            repetition_penalty=max(1.0, rep_pen),
+            eos_token_id=eos_id,
+            pad_token_id=pad_id,
+        )
+    else:
+        num_beams = max(2, int(gen_params.get("num_beams", 4)))
+        length_penalty = float(gen_params.get("length_penalty", 1.0))
+        kwargs = dict(
+            do_sample=False, num_beams=num_beams, length_penalty=length_penalty,
+            max_new_tokens=max_new,
+            min_new_tokens=max(0, min_new),
+            no_repeat_ngram_size=no_repeat,
+            repetition_penalty=max(1.0, rep_pen),
+            eos_token_id=eos_id,
+            pad_token_id=pad_id,
+        )
+
+    if bad_ids:  # only if present; avoids [[[...]]] and validation errors
+        kwargs["bad_words_ids"] = bad_ids
+
+    out_ids = tec_model.generate(**enc, **kwargs)
+    text = tec_tok.decode(out_ids[0], skip_special_tokens=True)
+
+    if persona_name == "Mori Normal":
+        text = truncate_sentences(text, max_sentences=1)
+        text = polish_spanish(text)
+
+    st.session_state["last_response"] = text
+    return text
+
+

 #***************************************************************************
 # MAIN
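
`load_rag_assets` only reads `Vec_DataBase/mori.faiss` and `mori_metas.json`; this commit does not show how that base is produced. A minimal offline sketch of how it could be built, assuming the same E5 model and the metadata fields that `rag_retrieve` reads back (the script name, record contents, and `passage:` text layout are all assumptions):

# build_vec_db.py (hypothetical): writes ./Vec_DataBase/mori.*
import json
from pathlib import Path

import faiss
from sentence_transformers import SentenceTransformer

records = [
    # one dict per passage, with the fields rag_retrieve expects
    {"id": "t-001", "canonical_term": "ejemplo", "context": "technical",
     "input": "¿...?", "output": "..."},
]
e5 = SentenceTransformer("intfloat/multilingual-e5-base")
# E5 convention: "passage: " prefix when indexing, "query: " when searching
texts = [f"passage: {r['input']} {r['output']}" for r in records]
vecs = e5.encode(texts, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
index = faiss.IndexFlatIP(vecs.shape[1])  # inner product equals cosine on unit vectors
index.add(vecs)
Path("Vec_DataBase").mkdir(exist_ok=True)
faiss.write_index(index, "Vec_DataBase/mori.faiss")
with open("Vec_DataBase/mori_metas.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False)

Because both query and passage vectors are L2-normalized, the scores returned by IndexFlatIP and compared against sim_threshold=0.40 are cosine similarities in [-1, 1].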
@@ -900,11 +1099,10 @@ if __name__ == '__main__':
         )

         # 🧠 Save history
-        hora_actual = dt.datetime.now().isoformat()
-
+        hora_actual = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         st.session_state.historial.append(("Tú", user_question, hora_actual))

-        hora_actual = dt.datetime.now().isoformat()
+        hora_actual = dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         st.session_state.historial.append(("Mori", response, hora_actual))

         # 💾 Save conversation
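
Switching from `isoformat()` to `strftime("%Y-%m-%d %H:%M:%S")` makes the history timestamps human-readable at the cost of sub-second precision. The same format string now appears twice; a small helper would keep both call sites in sync (a sketch, `now_str` is a hypothetical name):

def now_str() -> str:
    # single source of truth for the chat-history timestamp format
    return dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")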
 
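For a quick check outside the Streamlit UI, the new retrieval path can be exercised directly; a hedged sketch (the query is made up, and `st.cache_resource` functions still run in a bare script, albeit with missing-ScriptRunContext warnings):

# smoke_test_rag.py (hypothetical)
e5, index, metas = load_rag_assets("cpu")  # requires ./Vec_DataBase/mori.* to exist
hits = rag_retrieve(e5, index, metas, "¿Qué es un microcontrolador?", k=3)
for h in hits:
    print(h["rank"], round(h["score"], 3), h["canonical_term"])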