# core/agents/model_doc_sync.py
import os
import json
import time
import logging
from typing import Any, Dict

import requests
import pinecone
from openai import OpenAI
from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO)

# Initialize with error handling
try:
    # Load API keys
    TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")

    # Clients
    openai_client = OpenAI(api_key=OPENAI_API_KEY)
    pc = pinecone.Pinecone(api_key=PINECONE_API_KEY)

    # Embedding model (gte-large produces 1024-dimensional vectors)
    EMBEDDING_DIM = 1024
    EMBEDDER_MODEL = "thenlper/gte-large"
    embedder = SentenceTransformer(EMBEDDER_MODEL)

    # Constants
    PINECONE_INDEX = "pronto"
except Exception as e:
    raise RuntimeError(f"Initialization failed: {e}") from e
# 1. Search model docs (robust version)
def search_model_docs(model_name: str):
    """Fetches model pricing docs from Tavily with simple retry logic."""
    for attempt in range(3):  # Retry up to 3 times
        try:
            response = requests.post(
                "https://api.tavily.com/search",
                headers={"Authorization": f"Bearer {TAVILY_API_KEY}"},
                json={"query": f"{model_name} LLM pricing 2025", "max_results": 3},
                timeout=10,
            )
            response.raise_for_status()
            return response.json().get("results", [])
        except requests.exceptions.RequestException as e:
            print(f"⚠️ Retrying... ({e})")
            time.sleep(2 ** attempt)  # Back off briefly before the next attempt
            continue
    raise ConnectionError("Tavily API unavailable after 3 retries")
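# Note: each search result is consumed downstream as a dict carrying a "content"
# field; summarize_pricing() below reads only that field when building the prompt.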
def extract_llm_text(response) -> str:
    """Safely extract content from an OpenAI chat or legacy completion response."""
    try:
        # Chat completion response (e.g. gpt-4, gpt-3.5-turbo)
        return response.choices[0].message.content
    except (AttributeError, IndexError, TypeError):
        # Fallback: legacy text completion response (e.g. davinci)
        return response.choices[0].text
# 2. Summarize pricing (LLM extraction)
def summarize_pricing(docs: list):
    """
    Uses OpenAI to extract model pricing info from scraped documents.
    Expects a list of Tavily search result documents.
    """
    try:
        response = openai_client.chat.completions.create(
            model="gpt-4",
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a financial data extraction assistant. "
                        "Extract model pricing info as a JSON object with the following fields: "
                        "`model`, `input_cost`, `output_cost`, `max_tokens`, `source`."
                    ),
                },
                {
                    "role": "user",
                    "content": "\n\n".join([d.get("content", "") for d in docs]),
                },
            ],
            temperature=0,
            max_tokens=500,
        )
        raw_content = extract_llm_text(response)
        print(f"[DEBUG] Raw LLM Response:\n{raw_content}\n")
        result = json.loads(raw_content)

        # Normalize format: accept either a single object or a one-element list
        if isinstance(result, list) and len(result) > 0:
            return result[0]
        elif isinstance(result, dict):
            return result
        else:
            raise ValueError("Unexpected format from OpenAI output")
    except Exception as e:
        print(f"[ERROR] summarize_pricing failed: {e}")
        return None
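# Illustrative shape of the extracted pricing object, matching the fields requested
# in the system prompt above (values are hypothetical, not real pricing data):
# {"model": "gpt-4", "input_cost": "$30.00 / 1M tokens",
#  "output_cost": "$60.00 / 1M tokens", "max_tokens": 8192,
#  "source": "https://example.com/pricing"}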
# 3. Pinecone storage (fixed dimension)
def init_pinecone():
    """Ensures the index exists with the correct dimensions."""
    if PINECONE_INDEX not in pc.list_indexes().names():
        pc.create_index(
            name=PINECONE_INDEX,
            dimension=EMBEDDING_DIM,  # Must match the embedder's output size
            metric="cosine",
            spec=pinecone.ServerlessSpec(cloud="aws", region="us-east-1"),
        )
    return pc.Index(PINECONE_INDEX)
def store_pricing(model: str, data: Dict[str, Any]):
    """Improved storage with validation."""
    index = init_pinecone()
    model_id = model.strip().lower()

    # 1. Validate data structure
    required_fields = ["model", "input_cost", "output_cost"]
    if not all(field in data for field in required_fields):
        raise ValueError(f"Missing required fields in {data}")

    # 2. Generate and verify embedding
    embedding = embedder.encode(model_id).tolist()
    if len(embedding) != EMBEDDING_DIM:
        logging.warning(f"Padding embedding from {len(embedding)} to {EMBEDDING_DIM}")
        embedding += [0.0] * (EMBEDDING_DIM - len(embedding))

    # 3. Upsert with error handling
    try:
        response = index.upsert(
            vectors=[(model_id, embedding, data)],
            namespace="pricing",
        )
        logging.info(f"Stored {model_id}: {response}")
        return True
    except Exception as e:
        logging.error(f"Pinecone error: {e}")
        return False
def sync_pricing(model_names: list):
    """Fetch, summarize, and store model pricing info."""
    index = init_pinecone()
    for model in model_names:
        try:
            model = model.strip()
            logging.info(f"\n🔄 Syncing {model}...")
            docs = search_model_docs(model)
            if not docs:
                logging.warning(f"No docs found for {model}")
                continue
            pricing = summarize_pricing(docs)
            if not pricing:
                logging.error(f"No pricing extracted for {model}")
                continue
            if store_pricing(model, pricing):
                # Read the record back to confirm the upsert landed
                stored_data = index.fetch(ids=[model.lower()], namespace="pricing")
                if stored_data.vectors:
                    logging.info(f"✅ Verified storage for {model}")
                else:
                    logging.error(f"❌ Storage failed for {model}")
        except Exception as e:
            logging.error(f"Critical error with {model}: {e}")