Spaces:
Sleeping
Sleeping
Commit
·
b02093e
1
Parent(s):
88967ed
Fixed naking convention
Browse files- .github/workflows/main.yaml +7 -7
- __pycache__/logger_config.cpython-312.pyc +0 -0
- api/__pycache__/main.cpython-312.pyc +0 -0
- api/main.py +18 -1
- logger_config.py +40 -0
- logs/access.log +14 -0
- logs/app.log +1 -0
- logs/errors.log +0 -0
- src/genai/brainstroming_agent/utils/tools.py +4 -4
- src/genai/ideation_agent/utils/tools.py +7 -5
- src/genai/orchestration_agent/utils/__pycache__/utils.cpython-312.pyc +0 -0
- src/genai/orchestration_agent/utils/tools.py +9 -9
- src/genai/orchestration_agent/utils/utils.py +3 -1
- src/genai/utils/load_embeddings.py +15 -3
.github/workflows/main.yaml
CHANGED
|
@@ -18,33 +18,33 @@ jobs:
|
|
| 18 |
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
| 19 |
|
| 20 |
steps:
|
| 21 |
-
- name:
|
| 22 |
uses: actions/checkout@v3
|
| 23 |
|
| 24 |
-
- name:
|
| 25 |
uses: actions/setup-python@v4
|
| 26 |
with:
|
| 27 |
python-version: '3.13'
|
| 28 |
|
| 29 |
-
- name:
|
| 30 |
run: |
|
| 31 |
python -m pip install --upgrade pip
|
| 32 |
pip install -r requirements.txt
|
| 33 |
pip install pytest
|
| 34 |
|
| 35 |
-
- name:
|
| 36 |
run: pytest
|
| 37 |
|
| 38 |
-
- name:
|
| 39 |
uses: docker/setup-buildx-action@v3
|
| 40 |
|
| 41 |
-
- name:
|
| 42 |
uses: docker/login-action@v3
|
| 43 |
with:
|
| 44 |
username: ${{ secrets.DOCKER_USERNAME }}
|
| 45 |
password: ${{ secrets.DOCKER_PASSWORD }}
|
| 46 |
|
| 47 |
-
- name:
|
| 48 |
uses: docker/build-push-action@v5
|
| 49 |
with:
|
| 50 |
context: .
|
|
|
|
| 18 |
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
| 19 |
|
| 20 |
steps:
|
| 21 |
+
- name: Checkout code
|
| 22 |
uses: actions/checkout@v3
|
| 23 |
|
| 24 |
+
- name: Set up Python
|
| 25 |
uses: actions/setup-python@v4
|
| 26 |
with:
|
| 27 |
python-version: '3.13'
|
| 28 |
|
| 29 |
+
- name: Install dependencies
|
| 30 |
run: |
|
| 31 |
python -m pip install --upgrade pip
|
| 32 |
pip install -r requirements.txt
|
| 33 |
pip install pytest
|
| 34 |
|
| 35 |
+
- name: Run tests
|
| 36 |
run: pytest
|
| 37 |
|
| 38 |
+
- name: Set up Docker Buildx
|
| 39 |
uses: docker/setup-buildx-action@v3
|
| 40 |
|
| 41 |
+
- name: Log in to Docker Hub
|
| 42 |
uses: docker/login-action@v3
|
| 43 |
with:
|
| 44 |
username: ${{ secrets.DOCKER_USERNAME }}
|
| 45 |
password: ${{ secrets.DOCKER_PASSWORD }}
|
| 46 |
|
| 47 |
+
- name: Build and Push Docker image
|
| 48 |
uses: docker/build-push-action@v5
|
| 49 |
with:
|
| 50 |
context: .
|
__pycache__/logger_config.cpython-312.pyc
ADDED
|
Binary file (2.18 kB). View file
|
|
|
api/__pycache__/main.cpython-312.pyc
CHANGED
|
Binary files a/api/__pycache__/main.cpython-312.pyc and b/api/__pycache__/main.cpython-312.pyc differ
|
|
|
api/main.py
CHANGED
|
@@ -1,9 +1,26 @@
|
|
| 1 |
-
from fastapi import FastAPI
|
|
|
|
|
|
|
| 2 |
from .routers import orchestration, context_analysis, ideation , human_idea_refining , brainstorm , generate_final_story , generate_image, show_analytics
|
| 3 |
|
|
|
|
| 4 |
app = FastAPI()
|
| 5 |
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
@app.get("/")
|
| 8 |
async def root():
|
| 9 |
return {'response':'Hello'}
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Request
|
| 2 |
+
from logger_config import setup_loggers
|
| 3 |
+
import logging
|
| 4 |
from .routers import orchestration, context_analysis, ideation , human_idea_refining , brainstorm , generate_final_story , generate_image, show_analytics
|
| 5 |
|
| 6 |
+
setup_loggers()
|
| 7 |
app = FastAPI()
|
| 8 |
|
| 9 |
|
| 10 |
+
# Get loggers
|
| 11 |
+
app_logger = logging.getLogger("app_logger")
|
| 12 |
+
error_logger = logging.getLogger("error_logger")
|
| 13 |
+
access_logger = logging.getLogger("access_logger")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@app.middleware("http")
|
| 17 |
+
async def log_requests(request: Request, call_next):
|
| 18 |
+
access_logger.info(f"Request: {request.method} {request.url}")
|
| 19 |
+
response = await call_next(request)
|
| 20 |
+
access_logger.info(f"Response status: {response.status_code}")
|
| 21 |
+
return response
|
| 22 |
+
|
| 23 |
+
|
| 24 |
@app.get("/")
|
| 25 |
async def root():
|
| 26 |
return {'response':'Hello'}
|
logger_config.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# logger_config.py
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
def setup_loggers():
|
| 6 |
+
os.makedirs("logs", exist_ok=True)
|
| 7 |
+
|
| 8 |
+
# === Format ===
|
| 9 |
+
formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s")
|
| 10 |
+
|
| 11 |
+
# === App Logger ===
|
| 12 |
+
app_logger = logging.getLogger("app_logger")
|
| 13 |
+
app_handler = logging.FileHandler("logs/app.log")
|
| 14 |
+
app_handler.setLevel(logging.INFO)
|
| 15 |
+
app_handler.setFormatter(formatter)
|
| 16 |
+
app_logger.setLevel(logging.INFO)
|
| 17 |
+
app_logger.addHandler(app_handler)
|
| 18 |
+
|
| 19 |
+
# === Error Logger ===
|
| 20 |
+
error_logger = logging.getLogger("error_logger")
|
| 21 |
+
error_handler = logging.FileHandler("logs/errors.log")
|
| 22 |
+
error_handler.setLevel(logging.ERROR)
|
| 23 |
+
error_handler.setFormatter(formatter)
|
| 24 |
+
error_logger.setLevel(logging.ERROR)
|
| 25 |
+
error_logger.addHandler(error_handler)
|
| 26 |
+
|
| 27 |
+
# === Access Logger === (optional for request logs)
|
| 28 |
+
access_logger = logging.getLogger("access_logger")
|
| 29 |
+
access_handler = logging.FileHandler("logs/access.log")
|
| 30 |
+
access_handler.setLevel(logging.INFO)
|
| 31 |
+
access_handler.setFormatter(formatter)
|
| 32 |
+
access_logger.setLevel(logging.INFO)
|
| 33 |
+
access_logger.addHandler(access_handler)
|
| 34 |
+
|
| 35 |
+
# Optional: also log to console
|
| 36 |
+
console_handler = logging.StreamHandler()
|
| 37 |
+
console_handler.setFormatter(formatter)
|
| 38 |
+
app_logger.addHandler(console_handler)
|
| 39 |
+
error_logger.addHandler(console_handler)
|
| 40 |
+
access_logger.addHandler(console_handler)
|
logs/access.log
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-07-30 16:38:13,871 | INFO | access_logger | Request: GET http://127.0.0.1:8000/docs
|
| 2 |
+
2025-07-30 16:38:13,872 | INFO | access_logger | Response status: 200
|
| 3 |
+
2025-07-30 16:38:14,821 | INFO | access_logger | Request: GET http://127.0.0.1:8000/openapi.json
|
| 4 |
+
2025-07-30 16:38:14,833 | INFO | access_logger | Response status: 200
|
| 5 |
+
2025-07-30 16:38:28,560 | INFO | access_logger | Request: GET http://127.0.0.1:8000/
|
| 6 |
+
2025-07-30 16:38:28,561 | INFO | access_logger | Response status: 200
|
| 7 |
+
2025-07-30 16:38:56,487 | INFO | access_logger | Request: POST http://127.0.0.1:8000/api/human-idea-refining
|
| 8 |
+
2025-07-30 16:38:58,518 | INFO | access_logger | Response status: 200
|
| 9 |
+
2025-07-30 16:39:25,313 | INFO | access_logger | Request: POST http://127.0.0.1:8000/api/brainstorm
|
| 10 |
+
2025-07-30 16:39:36,441 | INFO | access_logger | Response status: 200
|
| 11 |
+
2025-07-30 16:40:01,761 | INFO | access_logger | Request: POST http://127.0.0.1:8000/api/brainstorm
|
| 12 |
+
2025-07-30 16:40:01,763 | INFO | access_logger | Response status: 422
|
| 13 |
+
2025-07-30 16:53:21,510 | INFO | access_logger | Request: POST http://127.0.0.1:8000/api/show-analytics
|
| 14 |
+
2025-07-30 16:53:23,607 | INFO | access_logger | Response status: 200
|
logs/app.log
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
2025-07-30 16:53:23,605 | INFO | app_logger | Showing Analytics of the influencers after context analysis.
|
logs/errors.log
ADDED
|
File without changes
|
src/genai/brainstroming_agent/utils/tools.py
CHANGED
|
@@ -5,7 +5,7 @@ from src.genai.utils.models_loader import embedding_model , llm
|
|
| 5 |
import numpy as np
|
| 6 |
import faiss
|
| 7 |
import tiktoken
|
| 8 |
-
from src.genai.utils.load_embeddings import
|
| 9 |
from src.genai.utils.utils import clean_text
|
| 10 |
|
| 11 |
def retrieve_tool(video_topic):
|
|
@@ -16,8 +16,8 @@ def retrieve_tool(video_topic):
|
|
| 16 |
query_embedding = np.array(embedding_model.embed_query(str(video_topic))).reshape(1, -1).astype('float32')
|
| 17 |
faiss.normalize_L2(query_embedding)
|
| 18 |
|
| 19 |
-
top_k = len(
|
| 20 |
-
distances, indices =
|
| 21 |
|
| 22 |
similarity_threshold = 0.35
|
| 23 |
selected = [(idx, sim) for idx, sim in zip(indices[0], distances[0]) if sim >= similarity_threshold]
|
|
@@ -28,7 +28,7 @@ def retrieve_tool(video_topic):
|
|
| 28 |
# === Format results ===
|
| 29 |
outer_list = []
|
| 30 |
for rank, (idx, sim) in enumerate(selected, 1):
|
| 31 |
-
row =
|
| 32 |
res = {
|
| 33 |
'rank': rank,
|
| 34 |
'username': row['username'],
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
import faiss
|
| 7 |
import tiktoken
|
| 8 |
+
from src.genai.utils.load_embeddings import caption_index , caption_df
|
| 9 |
from src.genai.utils.utils import clean_text
|
| 10 |
|
| 11 |
def retrieve_tool(video_topic):
|
|
|
|
| 16 |
query_embedding = np.array(embedding_model.embed_query(str(video_topic))).reshape(1, -1).astype('float32')
|
| 17 |
faiss.normalize_L2(query_embedding)
|
| 18 |
|
| 19 |
+
top_k = len(caption_df)
|
| 20 |
+
distances, indices = caption_index.search(query_embedding, top_k)
|
| 21 |
|
| 22 |
similarity_threshold = 0.35
|
| 23 |
selected = [(idx, sim) for idx, sim in zip(indices[0], distances[0]) if sim >= similarity_threshold]
|
|
|
|
| 28 |
# === Format results ===
|
| 29 |
outer_list = []
|
| 30 |
for rank, (idx, sim) in enumerate(selected, 1):
|
| 31 |
+
row = caption_df.iloc[idx]
|
| 32 |
res = {
|
| 33 |
'rank': rank,
|
| 34 |
'username': row['username'],
|
src/genai/ideation_agent/utils/tools.py
CHANGED
|
@@ -7,7 +7,7 @@ import ast
|
|
| 7 |
import faiss
|
| 8 |
import tiktoken
|
| 9 |
from src.genai.utils.models_loader import embedding_model
|
| 10 |
-
from src.genai.utils.load_embeddings import
|
| 11 |
from src.genai.utils.utils import clean_text
|
| 12 |
|
| 13 |
@tool("influencers_data_retrieval_tool", args_schema=QueryFormatter, return_direct=False,description="Retrieve influencer-related data for a given query.")
|
|
@@ -19,8 +19,8 @@ def retrieve_tool(business_details):
|
|
| 19 |
query_embedding = np.array(embedding_model.embed_query(str(business_details))).reshape(1, -1).astype('float32')
|
| 20 |
faiss.normalize_L2(query_embedding)
|
| 21 |
|
| 22 |
-
top_k = len(
|
| 23 |
-
distances, indices =
|
| 24 |
|
| 25 |
similarity_threshold = 0.35
|
| 26 |
selected = [(idx, sim) for idx, sim in zip(indices[0], distances[0]) if sim >= similarity_threshold]
|
|
@@ -31,7 +31,7 @@ def retrieve_tool(business_details):
|
|
| 31 |
# === Format results ===
|
| 32 |
outer_list = []
|
| 33 |
for rank, (idx, sim) in enumerate(selected, 1):
|
| 34 |
-
row =
|
| 35 |
res = {
|
| 36 |
'rank': rank,
|
| 37 |
'username': row['username'],
|
|
@@ -52,4 +52,6 @@ def retrieve_tool(business_details):
|
|
| 52 |
encoding = tiktoken.encoding_for_model('gpt-4o-mini')
|
| 53 |
tokens = encoding.encode(cleaned_response)
|
| 54 |
trimmed_response = tokens[:1000]
|
| 55 |
-
return encoding.decode(trimmed_response)
|
|
|
|
|
|
|
|
|
| 7 |
import faiss
|
| 8 |
import tiktoken
|
| 9 |
from src.genai.utils.models_loader import embedding_model
|
| 10 |
+
from src.genai.utils.load_embeddings import caption_embeddings , caption_index , caption_df
|
| 11 |
from src.genai.utils.utils import clean_text
|
| 12 |
|
| 13 |
@tool("influencers_data_retrieval_tool", args_schema=QueryFormatter, return_direct=False,description="Retrieve influencer-related data for a given query.")
|
|
|
|
| 19 |
query_embedding = np.array(embedding_model.embed_query(str(business_details))).reshape(1, -1).astype('float32')
|
| 20 |
faiss.normalize_L2(query_embedding)
|
| 21 |
|
| 22 |
+
top_k = len(caption_df)
|
| 23 |
+
distances, indices = caption_index.search(query_embedding, top_k)
|
| 24 |
|
| 25 |
similarity_threshold = 0.35
|
| 26 |
selected = [(idx, sim) for idx, sim in zip(indices[0], distances[0]) if sim >= similarity_threshold]
|
|
|
|
| 31 |
# === Format results ===
|
| 32 |
outer_list = []
|
| 33 |
for rank, (idx, sim) in enumerate(selected, 1):
|
| 34 |
+
row = caption_df.iloc[idx]
|
| 35 |
res = {
|
| 36 |
'rank': rank,
|
| 37 |
'username': row['username'],
|
|
|
|
| 52 |
encoding = tiktoken.encoding_for_model('gpt-4o-mini')
|
| 53 |
tokens = encoding.encode(cleaned_response)
|
| 54 |
trimmed_response = tokens[:1000]
|
| 55 |
+
return encoding.decode(trimmed_response)
|
| 56 |
+
|
| 57 |
+
|
src/genai/orchestration_agent/utils/__pycache__/utils.cpython-312.pyc
CHANGED
|
Binary files a/src/genai/orchestration_agent/utils/__pycache__/utils.cpython-312.pyc and b/src/genai/orchestration_agent/utils/__pycache__/utils.cpython-312.pyc differ
|
|
|
src/genai/orchestration_agent/utils/tools.py
CHANGED
|
@@ -2,7 +2,7 @@ import faiss
|
|
| 2 |
import ast
|
| 3 |
import pandas as pd
|
| 4 |
import numpy as np
|
| 5 |
-
from src.genai.utils.load_embeddings import
|
| 6 |
from src.genai.utils.models_loader import embedding_model
|
| 7 |
from src.genai.utils.utils import clean_text
|
| 8 |
import tiktoken
|
|
@@ -17,16 +17,16 @@ def retrieve_data_for_analytics(business_details):
|
|
| 17 |
# === Encode the query and search ===
|
| 18 |
query_embedding = np.array(embedding_model.embed_query(str(business_details))).reshape(1, -1).astype('float32')
|
| 19 |
top_k = 10
|
| 20 |
-
distances, indices =
|
| 21 |
|
| 22 |
# === Format results ===
|
| 23 |
results = []
|
| 24 |
for i, idx in enumerate(indices[0]):
|
| 25 |
-
likes =
|
| 26 |
-
comments =
|
| 27 |
res = {
|
| 28 |
-
'url':
|
| 29 |
-
'username':
|
| 30 |
'likesCount': int(likes) if pd.notnull(likes) else None,
|
| 31 |
'commentCount': int(comments) if pd.notnull(comments) else None
|
| 32 |
}
|
|
@@ -38,8 +38,8 @@ def retrieve_data_for_orchestration(query):
|
|
| 38 |
query_embedding = np.array(embedding_model.embed_query(str(query))).reshape(1, -1).astype('float32')
|
| 39 |
faiss.normalize_L2(query_embedding)
|
| 40 |
|
| 41 |
-
top_k = len(
|
| 42 |
-
distances, indices =
|
| 43 |
|
| 44 |
similarity_threshold = 0.35
|
| 45 |
selected = [(idx, sim) for idx, sim in zip(indices[0], distances[0]) if sim >= similarity_threshold]
|
|
@@ -50,7 +50,7 @@ def retrieve_data_for_orchestration(query):
|
|
| 50 |
# === Format results ===
|
| 51 |
outer_list = []
|
| 52 |
for rank, (idx, sim) in enumerate(selected, 1):
|
| 53 |
-
row =
|
| 54 |
res = {
|
| 55 |
'rank': rank,
|
| 56 |
'username': row['username'],
|
|
|
|
| 2 |
import ast
|
| 3 |
import pandas as pd
|
| 4 |
import numpy as np
|
| 5 |
+
from src.genai.utils.load_embeddings import caption_df, caption_embeddings , caption_index
|
| 6 |
from src.genai.utils.models_loader import embedding_model
|
| 7 |
from src.genai.utils.utils import clean_text
|
| 8 |
import tiktoken
|
|
|
|
| 17 |
# === Encode the query and search ===
|
| 18 |
query_embedding = np.array(embedding_model.embed_query(str(business_details))).reshape(1, -1).astype('float32')
|
| 19 |
top_k = 10
|
| 20 |
+
distances, indices = caption_index.search(query_embedding, top_k)
|
| 21 |
|
| 22 |
# === Format results ===
|
| 23 |
results = []
|
| 24 |
for i, idx in enumerate(indices[0]):
|
| 25 |
+
likes = caption_df.iloc[idx]['likesCount']
|
| 26 |
+
comments = caption_df.iloc[idx]['commentCount']
|
| 27 |
res = {
|
| 28 |
+
'url': caption_df.iloc[idx]['videoUrl'],
|
| 29 |
+
'username': caption_df.iloc[idx]['username'],
|
| 30 |
'likesCount': int(likes) if pd.notnull(likes) else None,
|
| 31 |
'commentCount': int(comments) if pd.notnull(comments) else None
|
| 32 |
}
|
|
|
|
| 38 |
query_embedding = np.array(embedding_model.embed_query(str(query))).reshape(1, -1).astype('float32')
|
| 39 |
faiss.normalize_L2(query_embedding)
|
| 40 |
|
| 41 |
+
top_k = len(caption_df)
|
| 42 |
+
distances, indices = caption_index.search(query_embedding, top_k)
|
| 43 |
|
| 44 |
similarity_threshold = 0.35
|
| 45 |
selected = [(idx, sim) for idx, sim in zip(indices[0], distances[0]) if sim >= similarity_threshold]
|
|
|
|
| 50 |
# === Format results ===
|
| 51 |
outer_list = []
|
| 52 |
for rank, (idx, sim) in enumerate(selected, 1):
|
| 53 |
+
row = caption_df.iloc[idx]
|
| 54 |
res = {
|
| 55 |
'rank': rank,
|
| 56 |
'username': row['username'],
|
src/genai/orchestration_agent/utils/utils.py
CHANGED
|
@@ -8,7 +8,8 @@ from src.genai.utils.models_loader import llm
|
|
| 8 |
from langchain_core.messages import FunctionMessage , AIMessage
|
| 9 |
from .tools import retrieve_data_for_analytics
|
| 10 |
import re
|
| 11 |
-
|
|
|
|
| 12 |
|
| 13 |
def caption_image(image_base64,user_input):
|
| 14 |
if len(image_base64)>0:
|
|
@@ -42,6 +43,7 @@ def caption_image(image_base64,user_input):
|
|
| 42 |
|
| 43 |
def show_analytics(business_details):
|
| 44 |
tool_response = retrieve_data_for_analytics(str(business_details))
|
|
|
|
| 45 |
return tool_response
|
| 46 |
|
| 47 |
def extract_latest_response_block(response):
|
|
|
|
| 8 |
from langchain_core.messages import FunctionMessage , AIMessage
|
| 9 |
from .tools import retrieve_data_for_analytics
|
| 10 |
import re
|
| 11 |
+
import logging
|
| 12 |
+
app_logger = logging.getLogger("app_logger")
|
| 13 |
|
| 14 |
def caption_image(image_base64,user_input):
|
| 15 |
if len(image_base64)>0:
|
|
|
|
| 43 |
|
| 44 |
def show_analytics(business_details):
|
| 45 |
tool_response = retrieve_data_for_analytics(str(business_details))
|
| 46 |
+
app_logger.info('Showing Analytics of the influencers after context analysis.')
|
| 47 |
return tool_response
|
| 48 |
|
| 49 |
def extract_latest_response_block(response):
|
src/genai/utils/load_embeddings.py
CHANGED
|
@@ -5,7 +5,7 @@ import pandas as pd
|
|
| 5 |
from datasets import load_dataset
|
| 6 |
|
| 7 |
|
| 8 |
-
def
|
| 9 |
dataset = load_dataset("DvorakInnovationAI/rt-genai-dataset-v1", revision="openai-embeddings")
|
| 10 |
df = dataset["train"]
|
| 11 |
df= df.to_pandas()
|
|
@@ -16,6 +16,18 @@ def load_index_once():
|
|
| 16 |
index.add(embeddings)
|
| 17 |
return df, embeddings, index
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
print('Loading Embeddings...........')
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
| 5 |
from datasets import load_dataset
|
| 6 |
|
| 7 |
|
| 8 |
+
def load_caption_index():
|
| 9 |
dataset = load_dataset("DvorakInnovationAI/rt-genai-dataset-v1", revision="openai-embeddings")
|
| 10 |
df = dataset["train"]
|
| 11 |
df= df.to_pandas()
|
|
|
|
| 16 |
index.add(embeddings)
|
| 17 |
return df, embeddings, index
|
| 18 |
|
| 19 |
+
def load_imdb_ideas_index():
|
| 20 |
+
dataset = load_dataset("DvorakInnovationAI/rt-genai-imdb-ideas-v1", revision='openai_embeddings')
|
| 21 |
+
df = dataset['train']
|
| 22 |
+
df= df.to_pandas()
|
| 23 |
+
df['embeddings'] = df['embeddings'].apply(lambda x: ast.literal_eval(x) if isinstance(x,str) else x)
|
| 24 |
+
embeddings = np.vstack(df['embeddings'].values).astype('float32')
|
| 25 |
+
faiss.normalize_L2(embeddings)
|
| 26 |
+
index = faiss.IndexFlatIP(embeddings.shape[1])
|
| 27 |
+
index.add(embeddings)
|
| 28 |
+
return df , embeddings , index
|
| 29 |
+
|
| 30 |
print('Loading Embeddings...........')
|
| 31 |
+
caption_df, caption_embeddings, caption_index = load_caption_index()
|
| 32 |
+
ideas_df , ideas_embeddings , ideas_index = load_imdb_ideas_index()
|
| 33 |
+
print('Embeddings Loaded.................')
|