shifaa_api / app /data_loader.py
MossaabDev's picture
Update app/data_loader.py
16a879c verified
raw
history blame
1.64 kB
import pandas as pd
from sentence_transformers import SentenceTransformer
import torch
from app.utils import remove_numbers
from app.qdrant_client import client
from qdrant_client.http import models
from pympler import asizeof
print("Loading model and data...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Fine-tuned sentence-embedding model, loaded from a path relative to the
# process working directory.
model = SentenceTransformer("app/my_finetuned_modelV2", device=device)
# NOTE(review): pympler's asizeof walks the Python object graph; confirm it
# reflects the (tensor) memory you actually want to report.
print("model size : ", asizeof.asizeof(model))

df = pd.read_csv("app/data/cleaned_fileV2.csv")
df['answer'] = df['answer'].apply(remove_numbers)

# Unique answers in first-occurrence order. dict.fromkeys deduplicates with the
# same hash/equality rules as set(), but its order is deterministic. A plain
# set of strings is reordered on every interpreter start (string hash
# randomization), which would make positions in `ayat` disagree with the
# enumerate()-based point ids stored in Qdrant by an earlier run.
ayat = list(dict.fromkeys(df['answer']))
# Report the count the message promises; the byte size is kept as extra info.
print(f"Total unique ayat loaded: {len(ayat)} ({asizeof.asizeof(ayat)} bytes)")
# Embeddings are only computed later (and only if the collection is missing),
# so at this point only the model and the data are ready.
print("✅ Model and data ready.")
# --- Ensure the Qdrant collection exists and is populated ---
COLLECTION_NAME = "ayahs_collection"
# Points sent per upsert request, so a large corpus is not pushed in one
# request that may exceed the server's request-size limit.
UPSERT_BATCH_SIZE = 256

collections = [c.name for c in client.get_collections().collections]
if COLLECTION_NAME not in collections:
    print("Creating Qdrant collection and uploading embeddings...")
    if not ayat:
        # Without this guard, embeddings[0] below fails with an opaque
        # IndexError; creating an empty collection instead would make every
        # later start skip the upload.
        raise ValueError(
            "No answers loaded from the CSV; refusing to create an empty "
            "Qdrant collection."
        )
    embeddings = model.encode(ayat, convert_to_tensor=False).tolist()
    # The collection is known to be absent here, so create_collection is
    # enough; recreate_collection (delete + create) is deprecated in recent
    # qdrant-client releases and would wipe data created concurrently.
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=models.VectorParams(
            size=len(embeddings[0]),
            distance=models.Distance.COSINE,
        ),
    )
    # Point ids are positions in `ayat`; the answer text travels as payload.
    points = [
        models.PointStruct(id=idx, vector=emb, payload={"text": ayah})
        for idx, (emb, ayah) in enumerate(zip(embeddings, ayat))
    ]
    # NOTE(review): if an upload fails part-way, the collection already exists
    # and later starts will skip the upload; consider checking the point count.
    for start in range(0, len(points), UPSERT_BATCH_SIZE):
        client.upsert(
            collection_name=COLLECTION_NAME,
            points=points[start:start + UPSERT_BATCH_SIZE],
        )
    print("✅ Embeddings uploaded to Qdrant.")
else:
    print("✅ Collection already exists, skipping upload.")
# Load embeddings from Qdrant