Spaces:
Sleeping
Sleeping
| import os | |
| import glob | |
| import time | |
| from typing import List | |
| import requests | |
| import uuid | |
| import json | |
| from app.core.database import upsert_points | |
# Configure Gemini: accept either env var name for the API key.
GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    # Fail fast at import time so a misconfigured deployment is caught early.
    raise ValueError("GEMINI_API_KEY must be set in .env")
# Using Gemini 1.5 Flash for Embeddings (REST API)
# Official Endpoint: https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent
# NOTE(review): the key is embedded as a URL query parameter, so it can leak
# into access logs / proxies — consider the "x-goog-api-key" header instead. TODO confirm.
EMBEDDING_API_URL = f"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key={GOOGLE_API_KEY}"
def get_embedding(text: str) -> List[float]:
    """
    Generate an embedding vector for *text* via the Gemini REST API.

    Retries up to 3 times with exponential backoff (1s, 2s, 4s) on
    HTTP 429 (rate limit), request timeouts, and transient connection
    errors. Any other HTTP status fails immediately.

    Args:
        text: The text to embed.

    Returns:
        The embedding as a list of floats
        (``response["embedding"]["values"]``).

    Raises:
        Exception: If the API returns a non-retryable error or all
            retries are exhausted.
    """
    payload = {
        "model": "models/text-embedding-004",
        "content": {"parts": [{"text": text}]},
    }
    max_retries = 3
    retry_delay = 1
    for attempt in range(max_retries):
        last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(
                EMBEDDING_API_URL,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=30,
            )
        except requests.exceptions.Timeout:
            if last_attempt:
                raise Exception("Embedding request timed out after retries")
            print(f"Embedding timeout. Retrying in {retry_delay}s...")
        except requests.exceptions.ConnectionError:
            # Transient network failure: retry with the same backoff
            # (the original let this escape unretried).
            if last_attempt:
                raise Exception("Embedding request failed after retries")
            print(f"Embedding connection error. Retrying in {retry_delay}s...")
        else:
            if response.status_code == 200:
                return response.json()["embedding"]["values"]
            if response.status_code == 429:
                # Rate limit - retry with backoff.
                if last_attempt:
                    raise Exception("Rate limit exceeded after retries")
                print(f"Embedding rate limit. Retrying in {retry_delay}s...")
            else:
                # Non-retryable HTTP error: fail immediately.
                print(f"Embedding Error ({response.status_code}): {response.text}")
                raise Exception(f"Failed to generate embedding: {response.status_code}")
        # Shared backoff tail for every retryable path above.
        time.sleep(retry_delay)
        retry_delay *= 2
def load_markdown_files(docs_path: str) -> List[dict]:
    """Recursively collect every ``.md`` file under *docs_path*.

    Args:
        docs_path: Root directory to search.

    Returns:
        A list of dicts, one per file, with keys ``"content"`` (full
        file text), ``"source"`` (basename) and ``"path"`` (full path).
    """
    def _load(filepath: str) -> dict:
        # Read the whole file as UTF-8 text.
        with open(filepath, 'r', encoding='utf-8') as handle:
            return {
                "content": handle.read(),
                "source": os.path.basename(filepath),
                "path": filepath,
            }

    pattern = os.path.join(docs_path, "**/*.md")
    return [_load(p) for p in glob.glob(pattern, recursive=True)]
| def chunk_text(text: str, chunk_size: int = 2000, overlap: int = 100) -> List[str]: | |
| chunks = [] | |
| start = 0 | |
| while start < len(text): | |
| end = start + chunk_size | |
| chunk = text[start:end] | |
| chunks.append(chunk) | |
| start += (chunk_size - overlap) | |
| return chunks | |
def process_and_index_documents(docs_path: str):
    """Embed every markdown file under *docs_path* and index it in Qdrant.

    Each file is split into overlapping chunks; each chunk is embedded
    via :func:`get_embedding` and upserted (in batches of 50) through
    :func:`upsert_points`. Failures on individual chunks are logged and
    skipped so one bad chunk does not abort the whole run.

    Returns:
        ``{"status": "success"}`` once all documents are processed.
    """
    print(f"Loading documents from: {docs_path}")
    documents = load_markdown_files(docs_path)
    print(f"Found {len(documents)} markdown files.")
    pending = []
    for document in documents:
        for chunk_id, piece in enumerate(chunk_text(document["content"])):
            try:
                vector = get_embedding(piece)
                # One Qdrant point per chunk, keyed by a fresh UUID.
                pending.append({
                    "id": str(uuid.uuid4()),
                    "vector": vector,
                    "payload": {
                        "text": piece,
                        "source": document["source"],
                        "path": document["path"],
                        "chunk_id": chunk_id,
                    },
                })
                # Flush in batches of 50 to keep request payloads small.
                if len(pending) >= 50:
                    upsert_points(pending)
                    pending = []
                print(".", end="", flush=True)
            except Exception as e:
                # Best-effort: log and move on to the next chunk.
                print(f"Error processing chunk in {document['source']}: {e}")
    # Flush whatever is left over after the last full batch.
    if pending:
        upsert_points(pending)
    print("\nUpload complete!")
    return {"status": "success"}