"""Memory extraction from conversation turns.

Three extractors are provided:

* ``fallback_extract`` - regex based, no network access.
* ``grok_extract`` / ``groq_extract`` - LLM based, via OpenAI-compatible
  clients; both fall back to the regex extractor on any failure.

``extract`` is the entry point: it selects a provider and memoizes results in
a JSON file (``CACHE_FILE``).
"""
import json
import os
import re
import time
import uuid
from typing import List

from .types import MemoryEntry, MemoryType
from .utils import env, extract_json

# Each entry is (regex, memory type, key, confidence).  Patterns are applied
# in order to the lower-cased turn text and every match yields one memory.
# The first capture group (when present) becomes the value; patterns without a
# group yield the literal value "true".
PATTERNS = [
    # Two-word form, e.g. "language is <word> <word>": the second word is
    # captured.  The words must be separated by whitespace (\s+); with \s* the
    # regex backtracked on a single word and captured only its last letter
    # ("language is English" -> "h").
    (r"\b(?:preferred language|language)\s*(?:is|:)\s*(?:[A-Za-z]+)\s+([A-Za-z]+)", "preference", "language", 0.92),
    (r"\b(?:language)\s*(?:is|:)\s*([A-Za-z]+)", "preference", "language", 0.92),
    # Handles "Call [me/us/...] after 9 AM"
    (r"\b(?:calls?|call)\s*(?:me|us|him|her)?\s*(?:after|before|at)\s*([0-9]{1,2}\s*(?:AM|PM|am|pm))", "preference", "call_time", 0.86),
    # Benchmark specific.  NOTE: "(.*)" is greedy, so the optional trailing
    # "[\.]?" never consumes anything; fallback_extract trims the period.
    (r"secret code is (.*)[\.]?", "fact", "secret_code", 0.95),
    (r"meeting is at (.*)[\.]?", "fact", "meeting_time", 0.95),
    (r"project deadline is (.*)[\.]?", "fact", "deadline", 0.95),
    (r"favorite color is (.*)[\.]?", "preference", "favorite_color", 0.95),
    (r"like\s+([a-zA-Z]+)[\.]?", "preference", "favorite_color", 0.96),
    (r"([a-zA-Z]+)\s+is\s+best[\.]?", "preference", "favorite_color", 0.97),
    # Handles "between X and Y"
    (r"\b(?:calls?|call).*(?:between)\s*([0-9]{1,2}\s*(?:AM|PM|am|pm)\s*(?:and|to)\s*[0-9]{1,2}\s*(?:AM|PM|am|pm))", "preference", "call_time", 0.89),
    (r"call.*(?:after|before|at)\s*([0-9]{1,2}\s*(?:AM|PM|am|pm))", "preference", "call_time", 0.80),
    (r"\b(?:no(?:thing)?|never|do not)\s*(?:call)?\s*on\s*sundays?", "constraint", "no_sundays", 0.88),
]


def fallback_extract(turn_text, turn_num) -> List[MemoryEntry]:
    """Extract memories from *turn_text* using the regexes in ``PATTERNS``.

    Args:
        turn_text: Raw text of the conversation turn.
        turn_num: Turn index, stored as ``source_turn`` on each entry.

    Returns:
        One ``MemoryEntry`` per matching pattern, in ``PATTERNS`` order
        (possibly empty).  Values come from the lower-cased text.
    """
    # NOTE(review): matching against the lower-cased text also lower-cases
    # captured values (e.g. secret codes) - confirm this is intended.
    lowered = turn_text.strip().lower()
    found = []
    for pattern, mtype, key, conf in PATTERNS:
        match = re.search(pattern, lowered, re.I)
        if not match:
            continue
        if match.groups():
            value = match.group(1).strip()
            # Drop the single sentence-ending period that the greedy "(.*)"
            # patterns swallow into the capture group.
            if value.endswith("."):
                value = value[:-1].rstrip()
        else:
            value = "true"
        found.append(MemoryEntry(
            memory_id=str(uuid.uuid4()),
            type=MemoryType(mtype),
            key=key,
            value=value,
            source_turn=turn_num,
            confidence=conf,
            source_text=turn_text[:240],
            meta={"extractor": "fallback"},
        ))
    return found


# NOTE: the original line breaks of this prompt were not recoverable from the
# collapsed source; one logical part per line is used here.
EXTRACTION_PROMPT = """You are extracting durable memories. Return ONLY JSON array.
Turn {turn_num}: {turn_text}
Schema: [{{"type":"preference|fact|constraint|commitment","key":"name","value":"val","confidence":0.7}}]
Extract only if confidence >= 0.70."""


class CircuitBreaker:
    """Failure counter that short-circuits API calls after repeated errors.

    The breaker is "open" (calls should be skipped) once ``failures`` reaches
    ``failure_threshold`` and less than ``recovery_timeout`` seconds have
    passed since the most recent failure.  After the timeout, calls are let
    through again; a success resets the counter, a failure re-opens it.
    """

    def __init__(self, failure_threshold=3, recovery_timeout=60):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failures = 0
        self.last_failure_time = 0

    def record_failure(self):
        """Count one failure and remember when it happened."""
        self.failures += 1
        self.last_failure_time = time.time()

    def record_success(self):
        """Reset the failure counter."""
        self.failures = 0

    def is_open(self):
        """Return True while calls should be skipped."""
        if self.failures < self.failure_threshold:
            return False
        # Threshold reached: stay open only until the recovery timeout has
        # elapsed, then let requests through again (half-open).
        return time.time() - self.last_failure_time < self.recovery_timeout


# NOTE(review): one breaker instance is shared by the Grok and Groq paths, so
# failures from either provider open the circuit for both - confirm intended.
_circuit_breaker = CircuitBreaker()
_grok_client = None
_groq_client = None


def _entries_from_payload(data, turn_text, turn_num, extractor) -> List[MemoryEntry]:
    """Convert parsed LLM JSON into MemoryEntry objects.

    *data* may be a single item or a list of items.  Items with a confidence
    below 0.70, and items that cannot be converted, are skipped.
    """
    items = data if isinstance(data, list) else [data]
    entries = []
    for item in items:
        try:
            conf = float(item.get("confidence", 0))
            if conf < 0.70:
                continue
            entries.append(MemoryEntry(
                memory_id=str(uuid.uuid4()),
                type=MemoryType(item["type"]),
                key=str(item["key"]),
                value=str(item["value"]),
                source_turn=turn_num,
                confidence=conf,
                source_text=turn_text[:240],
                meta={"extractor": extractor},
            ))
        except Exception:
            # Best effort: one malformed item must not discard the others.
            continue
    return entries


async def grok_extract(turn_text, turn_num) -> List[MemoryEntry]:
    """Extract memories with the Grok API, falling back to regexes.

    The regex fallback is used when the circuit breaker is open, the
    ``openai`` package or ``XAI_API_KEY`` is unavailable, the request fails,
    or the response yields no usable entries.
    """
    global _grok_client

    # 1. Circuit breaker check
    if _circuit_breaker.is_open():
        if turn_num % 50 == 0:  # throttle logging
            print("⚠️ Circuit Open (Grok): Skipping API for Regex Fallback")
        return fallback_extract(turn_text, turn_num)

    try:
        from openai import AsyncOpenAI
    except Exception:
        # Any import-time failure of the optional dependency means fallback.
        return fallback_extract(turn_text, turn_num)

    if _grok_client is None:
        key = env("XAI_API_KEY")
        if not key:
            return fallback_extract(turn_text, turn_num)
        try:
            # NOTE(review): xAI documents an OpenAI-compatible endpoint; verify
            # that this base URL and the model name below are still correct.
            _grok_client = AsyncOpenAI(
                api_key=key,
                base_url="https://api.x.ai/openai/",
                max_retries=0,
                timeout=1.0,
            )
        except Exception:
            return fallback_extract(turn_text, turn_num)

    try:
        resp = await _grok_client.chat.completions.create(
            model="grok-2",
            messages=[{
                "role": "user",
                "content": EXTRACTION_PROMPT.format(turn_num=turn_num, turn_text=turn_text),
            }],
            temperature=0.0,
            max_tokens=400,
        )
        data = extract_json(resp.choices[0].message.content)
        # The request itself succeeded, whatever the payload contains.
        _circuit_breaker.record_success()
        if not data:
            return fallback_extract(turn_text, turn_num)
        out = _entries_from_payload(data, turn_text, turn_num, "grok")
        return out if out else fallback_extract(turn_text, turn_num)
    except Exception:
        _circuit_breaker.record_failure()
        return fallback_extract(turn_text, turn_num)


async def groq_extract(turn_text, turn_num) -> List[MemoryEntry]:
    """Extract memories with the Groq API, falling back to regexes.

    The regex fallback is used when the circuit breaker is open, the
    ``openai`` package or ``GROQ_API_KEY`` is unavailable, the request fails,
    or the response yields no usable entries.
    """
    global _groq_client

    # 1. Circuit breaker check (instant failover)
    if _circuit_breaker.is_open():
        if turn_num % 50 == 0:  # excessive logging suppression
            print("⚠️ Circuit Open: Skipping API for Regex Fallback (Fast Path)")
        return fallback_extract(turn_text, turn_num)

    try:
        from openai import AsyncOpenAI
    except Exception:
        # Any import-time failure of the optional dependency means fallback.
        return fallback_extract(turn_text, turn_num)

    if _groq_client is None:
        key = env("GROQ_API_KEY")
        if not key:
            return fallback_extract(turn_text, turn_num)
        try:
            # max_retries=0 is CRITICAL for low latency on 429
            _groq_client = AsyncOpenAI(
                api_key=key,
                base_url="https://api.groq.com/openai/v1",
                max_retries=0,
                timeout=1.0,
            )
        except ImportError as e:
            print(f"❌ Groq Extract Import Error: {e}")
            return fallback_extract(turn_text, turn_num)
        except Exception as e:
            print(f"❌ Groq Extract Setup Error: {e}")
            return fallback_extract(turn_text, turn_num)

    try:
        resp = await _groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{
                "role": "user",
                "content": EXTRACTION_PROMPT.format(turn_num=turn_num, turn_text=turn_text),
            }],
            temperature=0.0,
            max_tokens=400,
        )
        data = extract_json(resp.choices[0].message.content)
        # The request itself succeeded, whatever the payload contains.
        _circuit_breaker.record_success()
        if not data:
            return fallback_extract(turn_text, turn_num)
        out = _entries_from_payload(data, turn_text, turn_num, "groq")
        return out if out else fallback_extract(turn_text, turn_num)
    except Exception as e:
        _circuit_breaker.record_failure()
        if "429" in str(e):
            # Only print the first one to avoid spam.
            if _circuit_breaker.failures == 1:
                print("⚠️ Groq Rate Limit (429). Switching to Circuit Breaker (Fast Fallback).")
        else:
            print(f"❌ Groq API Error: {e}")
        return fallback_extract(turn_text, turn_num)


CACHE_FILE = "artifacts/extraction_cache.json"
_EXT_CACHE = {}


def load_cache():
    """Load the extraction cache from ``CACHE_FILE`` into ``_EXT_CACHE``.

    A missing file leaves the cache untouched; an unreadable file or one that
    does not hold a JSON object resets the cache to empty.
    """
    global _EXT_CACHE
    if not os.path.exists(CACHE_FILE):
        return
    try:
        with open(CACHE_FILE, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except (OSError, ValueError):
        loaded = None
    # Anything other than a JSON object cannot be used as the key -> entries map.
    _EXT_CACHE = loaded if isinstance(loaded, dict) else {}


def save_cache():
    """Write ``_EXT_CACHE`` to ``CACHE_FILE`` (best effort, never raises for I/O)."""
    try:
        directory = os.path.dirname(CACHE_FILE)
        if directory:
            # Without this the write fails silently on every call whenever the
            # directory has not been created yet.
            os.makedirs(directory, exist_ok=True)
        with open(CACHE_FILE, "w", encoding="utf-8") as f:
            json.dump(_EXT_CACHE, f)
    except (OSError, TypeError, ValueError):
        # Caching is an optimization only; extraction must not fail because
        # the cache could not be persisted.
        pass


load_cache()


async def extract(turn_text, turn_num, provider="grok") -> List[MemoryEntry]:
    """Extract memories for one turn, using the on-disk cache when possible.

    Args:
        turn_text: Raw text of the conversation turn.
        turn_num: Turn index; also part of the cache key.
        provider: Default extractor ("grok", "groq", anything else selects the
            regex fallback).  Passed to ``env`` as the default for
            ``EXTRACTOR_PROVIDER``.

    Returns:
        The extracted entries.  Cache hits are rebuilt with fresh
        ``memory_id`` values.
    """
    cache_key = f"{turn_num}:{turn_text}"
    cached = _EXT_CACHE.get(cache_key)
    if cached is not None:
        try:
            return [MemoryEntry(
                memory_id=str(uuid.uuid4()),  # new ID to avoid conflicts
                type=MemoryType(d["type"]),
                key=d["key"],
                value=d["value"],
                source_turn=d["source_turn"],
                confidence=d["confidence"],
                source_text=d["source_text"],
                meta=d.get("meta", {}),
            ) for d in cached]
        except (AttributeError, KeyError, TypeError, ValueError):
            # Malformed cache entry (e.g. hand-edited file): recompute below.
            pass

    provider = env("EXTRACTOR_PROVIDER", provider).lower().strip()
    if provider == "grok":
        res = await grok_extract(turn_text, turn_num)
    elif provider == "groq":
        res = await groq_extract(turn_text, turn_num)
    else:
        res = fallback_extract(turn_text, turn_num)

    # Cache the result as plain JSON-serializable dicts.
    _EXT_CACHE[cache_key] = [{
        "type": m.type.value,
        "key": m.key,
        "value": m.value,
        "source_turn": m.source_turn,
        "confidence": m.confidence,
        "source_text": m.source_text,
        "meta": m.meta,
    } for m in res]
    if turn_num % 10 == 0:  # save periodically
        save_cache()
    return res