Spaces:
Sleeping
Sleeping
File size: 4,914 Bytes
1c29d49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
import glob
import time
from typing import List
import requests
import uuid
import json
from app.core.database import upsert_points
# Configure Gemini
# The key may be supplied under either env var name; GEMINI_API_KEY takes precedence.
GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    # Fail fast at import time: every embedding call below needs this key.
    raise ValueError("GEMINI_API_KEY must be set in .env")
# Using Gemini 1.5 Flash for Embeddings (REST API)
# Official Endpoint: https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent
# NOTE(review): embedding the key in the URL query string means it can leak into
# logs/tracebacks that print the URL; the `x-goog-api-key` header avoids that —
# confirm against the Gemini REST docs before changing.
EMBEDDING_API_URL = f"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key={GOOGLE_API_KEY}"
def get_embedding(text: str) -> List[float]:
    """
    Generate an embedding for *text* via the Gemini REST API.

    Retries up to 3 attempts with exponential backoff on HTTP 429 rate
    limits, timeouts, and transient connection errors. (The original
    implementation retried only timeouts and 429s; a dropped connection
    raised immediately, aborting long ingestion runs.)

    Args:
        text: The text to embed.

    Returns:
        The embedding vector as a list of floats.

    Raises:
        RuntimeError: if the request still fails after all retries, or the
            API returns a non-retryable error status. (RuntimeError is a
            subclass of Exception, so existing ``except Exception`` callers
            are unaffected.)
    """
    payload = {
        "model": "models/text-embedding-004",
        "content": {
            "parts": [{"text": text}]
        }
    }
    # Retry logic with exponential backoff
    max_retries = 3
    retry_delay = 1  # seconds; doubled after each failed attempt
    for attempt in range(max_retries):
        last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(
                EMBEDDING_API_URL,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=30,
            )
        except requests.exceptions.RequestException as e:
            # Covers Timeout AND transient transport errors (ConnectionError,
            # ChunkedEncodingError, ...) under one retryable umbrella.
            if last_attempt:
                raise RuntimeError(
                    f"Embedding request failed after {max_retries} attempts: {e}"
                ) from e
            print(f"Embedding request error ({type(e).__name__}). Retrying in {retry_delay}s...")
            time.sleep(retry_delay)
            retry_delay *= 2
            continue

        if response.status_code == 200:
            return response.json()["embedding"]["values"]

        if response.status_code == 429:
            # Rate limit - retry with backoff
            if not last_attempt:
                print(f"Embedding rate limit. Retrying in {retry_delay}s...")
                time.sleep(retry_delay)
                retry_delay *= 2
                continue
            raise RuntimeError("Rate limit exceeded after retries")

        # Any other status is non-retryable: surface the body for debugging.
        print(f"Embedding Error ({response.status_code}): {response.text}")
        raise RuntimeError(f"Failed to generate embedding: {response.status_code}")
def load_markdown_files(docs_path: str) -> List[dict]:
    """
    Recursively collect every ``*.md`` file under *docs_path*.

    Returns:
        A list of dicts, one per file, with keys:
        ``content`` - full file text (read as UTF-8),
        ``source``  - base filename,
        ``path``    - full path to the file.
    """
    loaded = []
    pattern = os.path.join(docs_path, "**/*.md")
    for md_path in glob.glob(pattern, recursive=True):
        with open(md_path, 'r', encoding='utf-8') as handle:
            text = handle.read()
        loaded.append({
            "content": text,
            "source": os.path.basename(md_path),
            "path": md_path,
        })
    return loaded
def chunk_text(text: str, chunk_size: int = 2000, overlap: int = 100) -> List[str]:
    """
    Split *text* into overlapping fixed-size chunks.

    Args:
        text: Input text; an empty string yields an empty list.
        chunk_size: Maximum characters per chunk (must be positive).
        overlap: Characters shared between consecutive chunks
            (must satisfy ``0 <= overlap < chunk_size``).

    Returns:
        List of chunks; each consecutive pair shares ``overlap`` characters.

    Raises:
        ValueError: on invalid parameters. The original implementation
            looped forever when ``overlap >= chunk_size`` (non-positive
            step meant ``start`` never advanced).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    step = chunk_size - overlap
    # Equivalent to the original while-loop: windows starting every `step` chars.
    return [text[start:start + chunk_size] for start in range(0, len(text), step)]
def process_and_index_documents(docs_path: str):
    """
    Index every Markdown file under *docs_path* into the vector store.

    Each file is split into overlapping chunks, embedded via the Gemini
    API, and upserted in batches of 50 points to keep payloads small.
    A chunk whose embedding (or batch upsert) fails is logged and
    skipped — indexing is best-effort and continues with the rest.
    """
    batch_limit = 50  # upsert in batches of 50 to avoid big payloads
    print(f"Loading documents from: {docs_path}")
    documents = load_markdown_files(docs_path)
    print(f"Found {len(documents)} markdown files.")

    pending = []
    for document in documents:
        for chunk_id, chunk in enumerate(chunk_text(document["content"])):
            try:
                vector = get_embedding(chunk)
                # Point structure expected by the Qdrant REST API.
                pending.append({
                    "id": str(uuid.uuid4()),
                    "vector": vector,
                    "payload": {
                        "text": chunk,
                        "source": document["source"],
                        "path": document["path"],
                        "chunk_id": chunk_id,
                    },
                })
                if len(pending) >= batch_limit:
                    upsert_points(pending)
                    pending = []
                    print(".", end="", flush=True)
            except Exception as e:
                print(f"Error processing chunk in {document['source']}: {e}")

    # Flush whatever is left over after the final file.
    if pending:
        upsert_points(pending)
    print("\nUpload complete!")
    return {"status": "success"}
|