Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| # Loading | |
| import os | |
| from os import makedirs,getcwd | |
| from os.path import join,exists,dirname | |
| from datasets import load_dataset | |
| import torch | |
| from tqdm import tqdm | |
| from sentence_transformers import SentenceTransformer | |
| import uuid | |
| from qdrant_client import models, QdrantClient | |
| from itertools import islice | |
# FastAPI application instance.
app = FastAPI()

# Parquet file (relative to the working directory) loaded by the import step.
FILEPATH_PATTERN = "structured_data_doc.parquet"

# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 1 so that arithmetic on NUM_PROC never raises TypeError.
NUM_PROC = os.cpu_count() or 1

# Scratch directory created beside (not inside) the current working directory.
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
# exist_ok=True removes the check-then-create race of "if not exists: makedirs".
makedirs(temp_path, exist_ok=True)
# Prefer the GPU when CUDA is usable; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Sentence-embedding model used both for indexing and for query encoding.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device=device)
def batched(iterable, n):
    """Split *iterable* into consecutive lists of at most *n* items.

    Args:
        iterable: Any iterable; it is consumed lazily, one batch at a time.
        n: Maximum batch length; must be at least 1.

    Returns:
        An iterator of lists. Every list holds exactly ``n`` items except
        possibly the last one, which holds the remainder. An empty iterable
        produces no batches.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    # With n == 0 every slice is empty, so iteration would stop immediately
    # and silently discard all items; reject it like itertools.batched does.
    if n < 1:
        raise ValueError("n must be at least one")
    iterator = iter(iterable)
    # iter(callable, sentinel) keeps calling until an empty list comes back,
    # i.e. until the source iterator is exhausted.
    return iter(lambda: list(islice(iterator, n)), [])
# Number of rows grouped into a single Qdrant upsert call.
batch_size = 100

# In-memory Qdrant instance.
client2 = QdrantClient(":memory:")

# Collection for the embeddings: vector size follows the model's embedding
# dimension, similarity is cosine distance.
_law_vectors_config = models.VectorParams(
    size=model.get_sentence_embedding_dimension(),
    distance=models.Distance.COSINE,
)
client2.create_collection(
    collection_name="law",
    vectors_config=_law_vectors_config,
)
def generate_embeddings(dataset, batch_size=32):
    """Encode the ``content`` column of *dataset* with the module-level model.

    Args:
        dataset: Object supporting ``len()`` and ``dataset['content']``, where
            the latter is a sliceable sequence of sentences.
        batch_size: Number of sentences passed to ``model.encode`` per call.

    Returns:
        A list of the embeddings produced by ``model.encode``, in row order.
    """
    # Fetch the column once; looking it up inside the loop repeats that work
    # for every batch.
    sentences = dataset['content']
    total = len(dataset)
    embeddings = []
    with tqdm(total=total, desc="Generating embeddings for dataset") as pbar:
        for start in range(0, total, batch_size):
            batch_sentences = sentences[start:start + batch_size]
            embeddings.extend(model.encode(batch_sentences))
            pbar.update(len(batch_sentences))
    return embeddings
async def create_upload_file(file: UploadFile = File(...)):
    """Embed the parquet dataset and load it into the ``law`` Qdrant collection.

    Args:
        file: Uploaded file. Only its ``filename`` is echoed back in the
            response.

    Returns:
        ``{"filename": <uploaded name>, "message": "Done"}`` after every row
        has been upserted.
    """
    # NOTE(review): the uploaded bytes are never read; the data always comes
    # from FILEPATH_PATTERN. Confirm whether the upload was meant to be saved
    # and indexed instead.
    # NOTE(review): the work below is synchronous inside an async function, so
    # it blocks the event loop while it runs.
    full_dataset = load_dataset(
        "parquet",
        data_files=FILEPATH_PATTERN,
        split="train",
        # The keyword is cache_dir; an unknown name such as cache_path is
        # forwarded to the builder config and rejected.
        cache_dir=temp_path,
        keep_in_memory=True,
        # NUM_PROC can be None when the CPU count is unknown.
        num_proc=(NUM_PROC or 1) * 2,
    )

    # Generate the embeddings and attach them as a new column.
    law_embeddings = generate_embeddings(full_dataset)
    full_dataset = full_dataset.add_column("embeddings", law_embeddings)

    # Every point needs an id; create random UUIDs when the data has none.
    if 'uuid' not in full_dataset.column_names:
        row_ids = [str(uuid.uuid4()) for _ in range(len(full_dataset))]
        full_dataset = full_dataset.add_column('uuid', row_ids)

    # Upsert the rows in batches.
    for batch in batched(full_dataset, batch_size):
        # Popping strips id and vector from each row dict, so what remains in
        # the row is exactly the payload stored next to the vector.
        ids = [point.pop("uuid") for point in batch]
        vectors = [point.pop("embeddings") for point in batch]
        client2.upsert(
            collection_name="law",
            points=models.Batch(
                ids=ids,
                vectors=vectors,
                payloads=batch,
            ),
        )
    return {"filename": file.filename, "message": "Done"}
# CORS policy: any origin, method and header is accepted, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive policy; confirm it is intended before exposing the service.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
def search(prompt: str):
    """Return the entries of the ``law`` collection closest to *prompt*.

    Args:
        prompt: Free-text query; it is embedded with the module-level model.

    Returns:
        A dict with three keys:

        * ``'results'``: up to five ``{'payload': ..., 'score': ...}`` dicts,
          in the order returned by Qdrant.
        * ``'detail'`` and ``'score:'``: payload and score of the last hit
          (the keys the function has always returned), or ``None`` for both
          when nothing matched.
    """
    hits = client2.search(
        collection_name="law",
        query_vector=model.encode(prompt).tolist(),
        limit=5,
    )

    results = []
    for hit in hits:
        print(hit.payload, "score:", hit.score)
        results.append({'payload': hit.payload, 'score': hit.score})

    # An empty collection or no match used to crash on the unbound loop
    # variable; report "no result" instead.
    if not results:
        return {'detail': None, 'score:': None, 'results': results}

    # 'detail' previously held the literal text 'hit.payload' rather than the
    # payload itself.
    last = results[-1]
    return {'detail': last['payload'], 'score:': last['score'], 'results': results}
def api_home():
    """Return the static greeting payload of the service."""
    greeting = 'Welcome to FastAPI Qdrant importer!'
    return {'detail': greeting}