Malaji71 committed on
Commit
e1fc349
·
verified ·
1 Parent(s): e51eb58

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +18 -42
agent.py CHANGED
@@ -1,4 +1,4 @@
1
- # agent.py — AGENTE SEMÁNTICO CON PRESERVACIÓN DE ENTIDADES v2.1
2
  import os
3
  import time
4
  import logging
@@ -10,35 +10,24 @@ import faiss
10
  import spacy
11
  from spacy.lang.en import English
12
 
13
- # Configurar logging
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
- # Cargar modelo de spaCy (con descarga automática si falta)
18
  try:
19
  NLP = spacy.load("en_core_web_sm")
20
  logger.info("✅ spaCy 'en_core_web_sm' cargado.")
21
  except OSError:
22
- logger.info("📥 Descargando 'en_core_web_sm' (primera ejecución)...")
23
  from spacy.cli import download
24
  download("en_core_web_sm")
25
  NLP = spacy.load("en_core_web_sm")
26
  logger.info("✅ spaCy 'en_core_web_sm' descargado y cargado.")
27
  except Exception as e:
28
- logger.warning(f"⚠️ Error inesperado con spaCy: {e}. Usando tokenizer básico.")
29
  NLP = English()
30
  NLP.add_pipe("sentencizer")
31
 
32
  class ImprovedSemanticAgent:
33
- """
34
- 🧠 AGENTE SEMÁNTICO CON PRESERVACIÓN DE ENTIDADES v2.1
35
-
36
- ✅ Extrae entidades clave con spaCy (descarga automática si es necesario).
37
- ✅ Filtra ejemplos que no comparten entidades con el usuario.
38
- ✅ Sintetiza prompts nuevos (no copia).
39
- ✅ Usa índice FAISS desde disco.
40
- """
41
-
42
  def __init__(self):
43
  logger.info("🚀 Cargando modelo de embeddings (bge-small-en-v1.5)...")
44
  self.embedding_model = SentenceTransformer('BAAI/bge-small-en-v1.5')
@@ -59,7 +48,7 @@ class ImprovedSemanticAgent:
59
  try:
60
  return future.result(timeout=60)
61
  except FutureTimeoutError:
62
- return "❌ Timeout inicializando agente (más de 60s)"
63
  except Exception as e:
64
  return f"❌ Error: {str(e)}"
65
 
@@ -86,7 +75,7 @@ class ImprovedSemanticAgent:
86
  if len(chunk.text) > 2 and not all(t.is_stop for t in chunk):
87
  entities.add(chunk.lemma_.replace(" ", "_"))
88
  text_lower = text.lower()
89
- if "fire" in text_lower or "flame" in text_lower or "burning" in text_lower:
90
  entities.add("on_fire")
91
  if "ice" in text_lower or "frozen" in text_lower:
92
  entities.add("frozen")
@@ -112,43 +101,30 @@ class ImprovedSemanticAgent:
112
  query_embedding = query_embedding.astype('float32').reshape(1, -1)
113
  distances, indices = self.index.search(query_embedding, 5)
114
 
115
- user_entities = self._extract_core_entities(user_prompt)
116
- logger.info(f"🔑 Entidades clave del usuario: {user_entities}")
117
-
118
  candidates = []
119
- filtered_count = 0
120
  for idx in indices[0]:
121
- if idx >= len(self.indexed_examples):
122
- continue
123
- caption = self.indexed_examples[idx]['caption']
124
- if not user_entities:
125
- candidates.append(caption)
126
- continue
127
- caption_entities = self._extract_core_entities(caption)
128
- if user_entities & caption_entities:
129
- candidates.append(caption)
130
- else:
131
- caption_lower = caption.lower()
132
- literal_match = any(
133
- ent.replace("_", " ") in caption_lower or ent in caption_lower
134
- for ent in user_entities
135
- )
136
- if literal_match:
137
- candidates.append(caption)
138
- else:
139
- filtered_count += 1
140
-
141
- logger.info(f"🗂️ Recuperados: {len(candidates)} ejemplos útiles ({filtered_count} filtrados)")
142
-
143
  if not candidates:
144
  return self._structural_fallback(user_prompt, category), "🔧 Fallback estructural"
145
 
 
 
 
146
  user_words = set(user_prompt.lower().split())
147
  all_parts = []
 
148
  for caption in candidates:
149
  parts = [p.strip() for p in caption.split(',') if 8 <= len(p) <= 120]
150
  for part in parts:
151
  part_lower = part.lower()
 
 
 
 
 
 
 
152
  if len(set(part_lower.split()) - user_words) >= 2:
153
  all_parts.append(part)
154
 
 
1
+ # agent.py — AGENTE SEMÁNTICO CON PRESERVACIÓN DE ENTIDADES
2
  import os
3
  import time
4
  import logging
 
10
  import spacy
11
  from spacy.lang.en import English
12
 
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
 
16
  try:
17
  NLP = spacy.load("en_core_web_sm")
18
  logger.info("✅ spaCy 'en_core_web_sm' cargado.")
19
  except OSError:
20
+ logger.info("📥 Descargando 'en_core_web_sm'...")
21
  from spacy.cli import download
22
  download("en_core_web_sm")
23
  NLP = spacy.load("en_core_web_sm")
24
  logger.info("✅ spaCy 'en_core_web_sm' descargado y cargado.")
25
  except Exception as e:
26
+ logger.warning(f"⚠️ Error con spaCy: {e}. Usando tokenizer básico.")
27
  NLP = English()
28
  NLP.add_pipe("sentencizer")
29
 
30
  class ImprovedSemanticAgent:
 
 
 
 
 
 
 
 
 
31
  def __init__(self):
32
  logger.info("🚀 Cargando modelo de embeddings (bge-small-en-v1.5)...")
33
  self.embedding_model = SentenceTransformer('BAAI/bge-small-en-v1.5')
 
48
  try:
49
  return future.result(timeout=60)
50
  except FutureTimeoutError:
51
+ return "❌ Timeout inicializando agente"
52
  except Exception as e:
53
  return f"❌ Error: {str(e)}"
54
 
 
75
  if len(chunk.text) > 2 and not all(t.is_stop for t in chunk):
76
  entities.add(chunk.lemma_.replace(" ", "_"))
77
  text_lower = text.lower()
78
+ if "fire" in text_lower or "flame" in text_lower:
79
  entities.add("on_fire")
80
  if "ice" in text_lower or "frozen" in text_lower:
81
  entities.add("frozen")
 
101
  query_embedding = query_embedding.astype('float32').reshape(1, -1)
102
  distances, indices = self.index.search(query_embedding, 5)
103
 
 
 
 
104
  candidates = []
 
105
  for idx in indices[0]:
106
+ if idx < len(self.indexed_examples):
107
+ candidates.append(self.indexed_examples[idx]['caption'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  if not candidates:
109
  return self._structural_fallback(user_prompt, category), "🔧 Fallback estructural"
110
 
111
+ # 🔑 EXTRAER ENTIDADES DEL USUARIO
112
+ user_entities = self._extract_core_entities(user_prompt)
113
+ user_has_clothing = any("swimsuit" in e or "dress" in e or "suit" in e or "armor" in e for e in user_entities)
114
  user_words = set(user_prompt.lower().split())
115
  all_parts = []
116
+
117
  for caption in candidates:
118
  parts = [p.strip() for p in caption.split(',') if 8 <= len(p) <= 120]
119
  for part in parts:
120
  part_lower = part.lower()
121
+ part_entities = self._extract_core_entities(part)
122
+ part_has_clothing = any("coat" in e or "jacket" in e or "scarf" in e or "hood" in e or "sweater" in e or "parka" in e for e in part_entities)
123
+
124
+ # ❌ Saltar si hay conflicto de ropa
125
+ if user_has_clothing and part_has_clothing:
126
+ continue
127
+
128
  if len(set(part_lower.split()) - user_words) >= 2:
129
  all_parts.append(part)
130