# Th3Nic3Guy's picture
# update
# 1507263
# App Main file
# from travel import ui as travel_ui
import os
import uuid
from typing import List, Sequence
import warnings
from langchain_community.document_loaders import CSVLoader
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
import gradio as gr
from event_ui import ui as events_ui
from fashion import ui as fashion_ui
from travel_v2 import ui as travel_ui
warnings.filterwarnings("ignore")
# Gemini model names offered in the UI dropdown; the first entry is the default.
MODELS_ENABLED = [
    "gemini-2.0-flash",
    "gemini-1.5-flash",
]
# Use a persistent local path for Qdrant data
# NOTE(review): the uuid4 suffix makes the path unique per process start, so
# the data does NOT actually persist across restarts — confirm this is intended.
QDRANT_PATH = './qdrant_data/'+uuid.uuid4().hex
# Embedded (file-backed, in-process) Qdrant instance — no server required.
qdrant_client = QdrantClient(path=QDRANT_PATH)
# Sentence-transformer embedding model; its output size must match the
# vector_size=768 used when the collection is created below.
model = SentenceTransformer("all-mpnet-base-v2")
# Collection name for storing document chunks
COLLECTION_NAME = 'tmp_collection'
# Function to create the Qdrant collection if it doesn't exist
def create_collection(collection_name: str, vector_size: int,
                      distance: str = "Cosine"):
    """
    Creates a Qdrant collection with the specified name, vector size, and
    distance metric.

    Args:
        collection_name (str): The name of the collection to create.
        vector_size (int): The size of the vectors to be stored in the
            collection.
        distance (str, optional): The distance metric to use for vector
            comparison. Defaults to "Cosine".
            Other options: "Dot", "Euclid"
    """
    # The docstring always advertised a `distance` option; expose it as a
    # real parameter. Unknown values fall back to cosine (the old behavior).
    metric_by_name = {
        "Cosine": models.Distance.COSINE,
        "Dot": models.Distance.DOT,
        "Euclid": models.Distance.EUCLID,
    }
    distance_m = metric_by_name.get(distance, models.Distance.COSINE)
    try:
        qdrant_client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=vector_size, distance=distance_m),
        )
        print(f"Collection '{collection_name}' created successfully.")
    except Exception as e:  # pylint: disable=broad-except
        # Most common cause: the collection already exists; log and continue.
        print(f"Error creating collection '{collection_name}': {e}")
# Function to chunk the text into smaller parts
def chunk_text(
    text: str,
    chunk_size: int = 500,
    chunk_overlap: int = 50
) -> Sequence[Document]:
    """
    Split one large text into overlapping chunk documents.

    Args:
        text (str): The raw text to split.
        chunk_size (int, optional): Maximum characters per chunk.
            Defaults to 500.
        chunk_overlap (int, optional): Characters shared between adjacent
            chunks. Defaults to 50.

    Returns:
        Sequence[Document]: One Document per chunk of *text*.
    """
    # Prefer splitting on paragraph, then line, then word boundaries before
    # falling back to a hard character cut.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    return splitter.create_documents([text])
# Function to embed the text chunks using the Sentence Transformer model
def embed_chunks(chunks: List[Document]) -> List[List[float]]:
    """
    Embed the text of each chunk with the module-level sentence-transformer.

    Args:
        chunks (List[Document]): Chunk documents whose ``page_content`` is
            embedded.

    Returns:
        List[List[float]]: One embedding vector per input chunk.
    """
    texts = [doc.page_content for doc in chunks]
    # .encode returns a numpy array; convert to plain lists for Qdrant.
    return model.encode(texts).tolist()
# Function to upload chunks to Qdrant
def upload_to_qdrant(
    chunks: List[Document],
    embeddings: List[List[float]],
    collection_name: str
):
    """
    Upsert text chunks with their embedding vectors into Qdrant.

    Args:
        chunks (List[Document]): A list of Document objects.
        embeddings (List[List[float]]): A list of embeddings for each chunk.
        collection_name (str): The name of the Qdrant collection to upload to.
    """
    # Build one point per (chunk, embedding) pair; the random uuid gives
    # each point a unique id so repeated uploads never collide.
    points = [
        models.PointStruct(
            id=uuid.uuid4().hex,
            vector=vector,
            payload={
                "text": doc.page_content,
                "metadata": doc.metadata,
            },
        )
        for doc, vector in zip(chunks, embeddings)
    ]
    qdrant_client.upsert(collection_name=collection_name, points=points)
def parse_document(file_path: str) -> str:
    """
    Read a text document from disk and return its full contents.

    Args:
        file_path (str): The path to the document file.

    Returns:
        str: The text content of the document.
    """
    with open(file_path, "r", encoding='utf-8') as handle:
        return handle.read()
def process_file(file_obj: gr.File) -> str:
    """
    Processes an uploaded file: parses it as CSV, chunks it, embeds the
    chunks, and uploads them to Qdrant.

    Args:
        file_obj (gr.File): The uploaded file object, or None when the
            button is clicked before a file has been chosen.

    Returns:
        str: A message indicating the success or failure of the process.
    """
    # Robustness: Gradio passes None if no file was uploaded; fail cleanly
    # instead of raising AttributeError on file_obj.name.
    if file_obj is None:
        return "Error processing file: no file uploaded."
    try:
        file_path = file_obj.name
        # create a collection if not exists
        if not qdrant_client.collection_exists(COLLECTION_NAME):
            create_collection(
                collection_name=COLLECTION_NAME,
                # Must match the embedding model's output dimension
                # (all-mpnet-base-v2 → 768).
                vector_size=768,
            )
        # Parse + chunk: CSVLoader yields one document per CSV row, which
        # the splitter further breaks into <=500-char chunks (50 overlap).
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", " ", ""],
        )
        chunks = CSVLoader(
            file_path=file_path
        ).load_and_split(
            text_splitter
        )
        embeddings = embed_chunks(chunks)
        upload_to_qdrant(chunks, embeddings, COLLECTION_NAME)
        print(len(chunks), "chunks uploaded to Qdrant.")
        return f"File '{os.path.basename(file_path)}' processed!"
    except Exception as e:  # pylint: disable=broad-except
        # Surface the error in the UI status box instead of crashing the app.
        return f"Error processing file: {e}"
# --- Gradio UI: model config, dataset upload, and the three planner tabs ---
with gr.Blocks(
    title='Planner Demos',
    # theme=gr.themes.Origin(),
) as demo:
    # Fixed typo: "ypur" -> "your".
    gr.Markdown("""# Sample GenAI Demos
> Note: get your gemini API key from:
> https://ai.google.dev/gemini-api/docs/api-key
""")
    with gr.Accordion(label='Model Config') as config:
        # Hidden field pre-filled from the environment; tabs read it as input.
        api_key = gr.Text(
            placeholder='Gemini API key',
            label='Gemini API Key',
            interactive=True,
            value=os.getenv("GEMINI_API_KEY"),
            visible=False
        )
        gemini_model_name = gr.Dropdown(
            label='Gemini Model',
            value=MODELS_ENABLED[0],
            choices=MODELS_ENABLED,
        )
    with gr.Accordion(
        label='Upload Personal Dataset',
        open=False
    ) as dataset:
        dataset_upload = gr.File(
            label='Upload Personal Dataset',
            interactive=True,
        )
        upload_button = gr.Button("Process and Upload")
        output = gr.Textbox(label="Status")
        upload_button.click(  # pylint: disable=no-member
            process_file,
            inputs=dataset_upload,
            outputs=output
        )
    with gr.Accordion(label='Planners') as planners:
        with gr.Tab(label='Travel Planner'):
            travel_ui(api_key, gemini_model_name)
        with gr.Tab(label='Fashion Advisor'):
            fashion_ui(api_key, gemini_model_name)
        # NOTE(review): this tab is labeled "Beauty Advisor" but is backed by
        # events_ui (the event_ui module) — confirm the pairing is intended.
        with gr.Tab(label='Beauty Advisor'):
            events_ui(api_key, gemini_model_name)

# PORT env var overrides the default Gradio port (7860).
demo.launch(debug=True, server_port=int(os.getenv("PORT", "7860")))