# matching / index_objects.py
# Calcifer0323's picture
# Fix: Update to RoSBERTa model (1024 dims), remove half precision, increase timeout
# 93cd57d
import requests
import json
import re
import sqlite3
import time
# Load the whole SQL dump into memory.
with open('pars_samolet.sql', 'r', encoding='utf-8') as f:
    sql_content = f.read()

# Locate the CREATE TABLE statement: everything from the keyword up to the
# first ';' that follows it.
create_match = re.compile(r'CREATE TABLE[^;]+;', re.DOTALL).search(sql_content)
if create_match is None:
    raise ValueError("CREATE TABLE not found")
create_stmt = create_match[0]

# Collect every INSERT statement that targets `mytable`.
# NOTE(review): the VALUES part is matched with `[^;]+`, so a ';' inside a
# value is never consumed as part of the statement -- confirm the dump does
# not contain semicolons in its data.
insert_pattern = r'INSERT INTO mytable\([^)]+\) VALUES\s*\([^;]+\);'
inserts = re.compile(insert_pattern, re.DOTALL).findall(sql_content)

# Build a throwaway in-memory SQLite database and create the table in it.
conn = sqlite3.connect(':memory:')
cursor = conn.cursor()
cursor.execute(create_stmt)
# Helper for escaping text that goes inside a SQL string literal.
def escape_sql_value(value):
    """Return *value* as text suitable for the inside of a quoted SQL literal.

    ``None`` yields the bare word ``NULL``. Any other value is converted with
    ``str()`` and each single quote is doubled. The surrounding quotes are
    not added here.
    """
    if value is not None:
        # SQL escapes a single quote by doubling it.
        return str(value).replace("'", "''")
    return 'NULL'
# Execute every INSERT. When SQLite rejects a statement, rebuild its VALUES
# list with normalised string literals and retry once; statements that still
# fail are reported and skipped.
def _split_sql_values(values_str):
    """Split the text of a VALUES list on commas that lie outside string literals.

    A single quote toggles the "inside a string" state unless the character
    immediately before it is a backslash. Each piece is returned stripped of
    surrounding whitespace.
    """
    values = []
    current = []
    in_string = False
    prev = ''
    for char in values_str:
        if char == "'" and prev != '\\':
            in_string = not in_string
        if char == ',' and not in_string:
            # Separator between two values: close the current piece.
            values.append(''.join(current).strip())
            current = []
        else:
            current.append(char)
        prev = char
    if current:
        values.append(''.join(current).strip())
    return values


def _clean_sql_value(val):
    """Return one VALUES item rewritten so that SQLite can parse it.

    ``NULL`` (any case) and non-string items are passed through. For a quoted
    string, existing quote escapes (``\\'`` or ``''``) are first reduced to a
    bare quote, whitespace runs (including newlines) are collapsed to one
    space, and the result is escaped again by doubling quotes.
    """
    if val.upper() == 'NULL':
        return 'NULL'
    if val.startswith("'") and val.endswith("'"):
        inner = val[1:-1]
        # Undo the dump's own escaping before re-escaping; otherwise an
        # already-doubled quote would be doubled again and a backslash escape
        # would leave a stray backslash in the stored text.
        inner = re.sub(r"\\'|''", "'", inner)
        inner = ' '.join(inner.split())
        return "'" + escape_sql_value(inner) + "'"
    # Numeric or any other kind of value: leave untouched.
    return val


def _repair_insert(insert):
    """Return *insert* with a cleaned VALUES list, or ``None`` if none is found."""
    match = re.search(r'VALUES\s*\((.+)\);', insert, re.DOTALL)
    if match is None:
        return None
    cleaned_values = [_clean_sql_value(v) for v in _split_sql_values(match.group(1))]
    return insert[:match.start(1)] + ', '.join(cleaned_values) + insert[match.end(1):]


for insert in inserts:
    try:
        cursor.execute(insert)
    except sqlite3.OperationalError:
        new_insert = _repair_insert(insert)
        if new_insert is None:
            print(f"Could not parse INSERT: {insert[:100]}...")
            continue
        try:
            cursor.execute(new_insert)
            print("Fixed problematic INSERT")
        except Exception as e2:
            # Best effort: a row that cannot be repaired is skipped so that
            # the remaining rows are still loaded.
            print(f"Still failed to execute INSERT: {e2}")
# Read back every row that was stored in the table.
cursor.execute('SELECT * FROM mytable')
rows = cursor.fetchall()

# The column name is the second field of each PRAGMA table_info record.
cursor.execute("PRAGMA table_info(mytable)")
columns = [name for _, name, *_ in cursor.fetchall()]

# Pair column names with row values: one dict per row.
objects = [{name: value for name, value in zip(columns, row)} for row in rows]

conn.close()
print(f"Total objects parsed: {len(objects)}")
# Split the objects into batches of `batch_size` and keep only the first four.
batch_size = 50
batches = [objects[i:i + batch_size] for i in range(0, len(objects), batch_size)][:4]

# Endpoint that receives a batch of items.
url = 'https://calcifer0323-matching.hf.space/batch'


def _to_text(value):
    """Return ``''`` for ``None``, otherwise ``str(value)``.

    Rows read from SQLite contain ``None`` for NULL columns; a plain
    ``str()`` would send the literal text "None" instead of an empty string.
    """
    return "" if value is None else str(value)


def _to_float(value):
    """Return *value* as a float, or ``None`` when it is falsy or not numeric."""
    if not value:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        # A malformed cell must not abort the whole run.
        return None


responses = []
for i, batch in enumerate(batches):
    print(f"Sending batch {i+1} with {len(batch)} objects")

    # Convert the parsed rows into the item dicts sent to /batch.
    items = [
        {
            "entity_id": str(obj["property_id"]),
            "title": _to_text(obj.get("title")),
            "description": _to_text(obj.get("description")),
            "price": _to_float(obj.get("price")),
            "rooms": _to_float(obj.get("rooms")),
            "area": _to_float(obj.get("area")),
            "address": _to_text(obj.get("address")),
            # The "district" field is filled from the row's "city" column.
            "district": _to_text(obj.get("city")),
        }
        for obj in batch
    ]
    payload = {"items": items}

    try:
        response = requests.post(url, json=payload, timeout=300)
        if response.status_code == 200:
            data = response.json()
            responses.append(data)
            successful = data.get('successful', 0)
            total = data.get('total', 0)
            print(f"Batch {i+1} successful, embedded {successful}/{total}")
        else:
            print(f"Batch {i+1} failed: {response.status_code} - {response.text}")
    except Exception as e:
        # Best effort: a failed batch is reported and the run continues.
        print(f"Error sending batch {i+1}: {e}")

    # Pause between batches (not after the last one).
    if i < len(batches) - 1:
        print("Waiting 10 seconds before next batch...")
        time.sleep(10)
# Write the collected embeddings to a SQL script: table definition, a DELETE
# that clears previous rows, then one INSERT per successfully embedded item.
with open('indexed_objects.sql', 'w', encoding='utf-8') as f:
    f.write("CREATE TABLE IF NOT EXISTS indexed_objects (\n")
    f.write(" property_id VARCHAR(36) PRIMARY KEY,\n")
    f.write(" embedding JSON\n")
    f.write(");\n\n")
    f.write("DELETE FROM indexed_objects;\n\n")
    for resp in responses:
        for result in resp.get("results", []):
            if not result.get("success"):
                continue
            # Both values are placed inside single-quoted SQL literals, so
            # any single quote in them must be doubled -- including the id,
            # which comes back from the remote service.
            property_id = str(result["entity_id"]).replace("'", "''")
            embedding_escaped = json.dumps(result["embedding"]).replace("'", "''")
            f.write(
                "INSERT INTO indexed_objects (property_id, embedding) "
                f"VALUES ('{property_id}', '{embedding_escaped}');\n"
            )

print("Indexing complete. Results saved to indexed_objects.sql")
print(f"Total batches processed: {len(responses)}")