Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- app.py +0 -3
- helpers.py +25 -15
- requirements.txt +1 -2
app.py
CHANGED
|
@@ -14,7 +14,6 @@ load_dotenv()
|
|
| 14 |
# Function to initialize APIs
|
| 15 |
def initialize_apis():
|
| 16 |
if "openai_api_key" in st.session_state and "cohere_api_key" in st.session_state:
|
| 17 |
-
openai.api_key = st.session_state["openai_api_key"]
|
| 18 |
co = cohere.Client(st.session_state["cohere_api_key"])
|
| 19 |
index = helpers.initialize_pinecone(
|
| 20 |
st.session_state["api_key"], st.session_state["env"], "coherererank", 1536
|
|
@@ -86,5 +85,3 @@ if all(
|
|
| 86 |
st.warning(error)
|
| 87 |
else:
|
| 88 |
st.warning("Please enter a query.")
|
| 89 |
-
|
| 90 |
-
|
|
|
|
| 14 |
# Function to initialize APIs
|
| 15 |
def initialize_apis():
|
| 16 |
if "openai_api_key" in st.session_state and "cohere_api_key" in st.session_state:
|
|
|
|
| 17 |
co = cohere.Client(st.session_state["cohere_api_key"])
|
| 18 |
index = helpers.initialize_pinecone(
|
| 19 |
st.session_state["api_key"], st.session_state["env"], "coherererank", 1536
|
|
|
|
| 85 |
st.warning(error)
|
| 86 |
else:
|
| 87 |
st.warning("Please enter a query.")
|
|
|
|
|
|
helpers.py
CHANGED
|
@@ -1,8 +1,16 @@
|
|
| 1 |
import random
|
| 2 |
import time
|
|
|
|
| 3 |
|
| 4 |
import faker
|
| 5 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import pinecone
|
| 7 |
import tqdm
|
| 8 |
from datasets import Dataset
|
|
@@ -79,10 +87,10 @@ def create_dataset(num_resumes=1000, chunk_size=800):
|
|
| 79 |
|
| 80 |
def embed(docs: list[str]) -> list[list[float]]:
|
| 81 |
print("Embedding documents...")
|
| 82 |
-
res =
|
| 83 |
print("Documents embedded successfully!")
|
| 84 |
-
|
| 85 |
-
|
| 86 |
|
| 87 |
|
| 88 |
def insert_to_pinecone(index, dataset, batch_size=100):
|
|
@@ -117,12 +125,12 @@ def insert_to_pinecone(index, dataset, batch_size=100):
|
|
| 117 |
|
| 118 |
print("New data inserted to Pinecone successfully!")
|
| 119 |
|
| 120 |
-
|
| 121 |
def get_docs(index, query: str, top_k: int):
|
| 122 |
print("Fetching documents from Pinecone...")
|
| 123 |
xq = embed([query])[0]
|
| 124 |
res = index.query(xq, top_k=top_k, include_metadata=True)
|
| 125 |
-
docs = {x["metadata"]["text"]: i for i, x in enumerate(res
|
| 126 |
print("Documents fetched successfully!")
|
| 127 |
return docs
|
| 128 |
|
|
@@ -131,7 +139,7 @@ def compare(index, co, query, top_k=25, top_n=3):
|
|
| 131 |
# Get vec search results
|
| 132 |
docs = get_docs(index, query, top_k=top_k)
|
| 133 |
i2doc = {docs[doc]: doc for doc in docs.keys()}
|
| 134 |
-
|
| 135 |
# Re-rank
|
| 136 |
rerank_docs = co.rerank(
|
| 137 |
query=query,
|
|
@@ -139,18 +147,20 @@ def compare(index, co, query, top_k=25, top_n=3):
|
|
| 139 |
top_n=top_n,
|
| 140 |
model="rerank-english-v2.0",
|
| 141 |
)
|
| 142 |
-
|
| 143 |
comparison_data = []
|
| 144 |
# Compare order change
|
| 145 |
for i, doc in enumerate(rerank_docs):
|
| 146 |
rerank_i = docs[doc.document["text"]]
|
| 147 |
-
|
| 148 |
-
comparison_data.append(
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
| 154 |
return comparison_data
|
| 155 |
|
| 156 |
|
|
|
|
| 1 |
import random
|
| 2 |
import time
|
| 3 |
+
import os
|
| 4 |
|
| 5 |
import faker
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 13 |
+
|
| 14 |
import pinecone
|
| 15 |
import tqdm
|
| 16 |
from datasets import Dataset
|
|
|
|
| 87 |
|
| 88 |
def embed(docs: list[str]) -> list[list[float]]:
    """Embed a batch of texts with OpenAI's ``text-embedding-3-small`` model.

    Args:
        docs: Texts to embed; they are sent to the API as a single request.

    Returns:
        One embedding per entry of the response's ``data`` list, in the
        order the response provides them. An empty list when ``docs`` is
        empty (no request is made in that case).
    """
    # Nothing to embed: return early instead of sending an empty batch
    # to the embeddings endpoint.
    if not docs:
        return []

    print("Embedding documents...")
    # `client` is the module-level OpenAI client created at import time.
    # NOTE(review): `client.embeddings.create` is the openai>=1.0 client
    # interface -- confirm the installed openai version provides it.
    res = client.embeddings.create(input=docs, model="text-embedding-3-small")
    print("Documents embedded successfully!")
    # Each item of `res.data` carries its vector on the `.embedding` attribute.
    return [item.embedding for item in res.data]
|
| 94 |
|
| 95 |
|
| 96 |
def insert_to_pinecone(index, dataset, batch_size=100):
|
|
|
|
| 125 |
|
| 126 |
print("New data inserted to Pinecone successfully!")
|
| 127 |
|
| 128 |
+
|
| 129 |
def get_docs(index, query: str, top_k: int):
    """Fetch the nearest stored documents for ``query`` from the vector index.

    Args:
        index: Object exposing ``query(vector=..., top_k=...,
            include_metadata=...)`` whose result has a ``matches``
            attribute -- presumably a ``pinecone.Index``; verify against
            the caller.
        query: Search text; it is embedded with ``embed`` before querying.
        top_k: Number of matches to request from the index.

    Returns:
        Dict mapping each match's ``metadata["text"]`` to its position in
        ``res.matches``. If two matches carry identical text, the later
        position overwrites the earlier one, so the dict can hold fewer
        than ``len(res.matches)`` entries.
    """
    print("Fetching documents from Pinecone...")
    query_vector = embed([query])[0]
    # Pass the vector by keyword so the call does not rely on the client
    # accepting it as the first positional argument.
    res = index.query(vector=query_vector, top_k=top_k, include_metadata=True)
    docs = {
        match["metadata"]["text"]: rank
        for rank, match in enumerate(res.matches)
    }
    print("Documents fetched successfully!")
    return docs
|
| 136 |
|
|
|
|
| 139 |
# Get vec search results
|
| 140 |
docs = get_docs(index, query, top_k=top_k)
|
| 141 |
i2doc = {docs[doc]: doc for doc in docs.keys()}
|
| 142 |
+
|
| 143 |
# Re-rank
|
| 144 |
rerank_docs = co.rerank(
|
| 145 |
query=query,
|
|
|
|
| 147 |
top_n=top_n,
|
| 148 |
model="rerank-english-v2.0",
|
| 149 |
)
|
| 150 |
+
|
| 151 |
comparison_data = []
|
| 152 |
# Compare order change
|
| 153 |
for i, doc in enumerate(rerank_docs):
|
| 154 |
rerank_i = docs[doc.document["text"]]
|
| 155 |
+
|
| 156 |
+
comparison_data.append(
|
| 157 |
+
{
|
| 158 |
+
"Original Rank": i,
|
| 159 |
+
"Original Text": i2doc[i],
|
| 160 |
+
"Reranked Rank": rerank_i,
|
| 161 |
+
"Reranked Text": doc.document["text"],
|
| 162 |
+
}
|
| 163 |
+
)
|
| 164 |
return comparison_data
|
| 165 |
|
| 166 |
|
requirements.txt
CHANGED
|
@@ -63,7 +63,6 @@ jupyter_core==5.4.0
|
|
| 63 |
langchain==0.0.325
|
| 64 |
langsmith==0.0.53
|
| 65 |
Levenshtein==0.23.0
|
| 66 |
-
llama-index==0.8.53.post3
|
| 67 |
loguru==0.7.2
|
| 68 |
markdown-it-py==3.0.0
|
| 69 |
MarkupSafe==2.1.3
|
|
@@ -76,7 +75,7 @@ mypy-extensions==1.0.0
|
|
| 76 |
nest-asyncio==1.5.8
|
| 77 |
nltk==3.8.1
|
| 78 |
numpy==1.26.1
|
| 79 |
-
openai==0.28.
|
| 80 |
openpyxl==3.1.2
|
| 81 |
packaging==23.2
|
| 82 |
pandas==2.1.2
|
|
|
|
| 63 |
langchain==0.0.325
|
| 64 |
langsmith==0.0.53
|
| 65 |
Levenshtein==0.23.0
|
|
|
|
| 66 |
loguru==0.7.2
|
| 67 |
markdown-it-py==3.0.0
|
| 68 |
MarkupSafe==2.1.3
|
|
|
|
| 75 |
nest-asyncio==1.5.8
|
| 76 |
nltk==3.8.1
|
| 77 |
numpy==1.26.1
|
| 78 |
+
openai==0.28.0
|
| 79 |
openpyxl==3.1.2
|
| 80 |
packaging==23.2
|
| 81 |
pandas==2.1.2
|