# ChanceRAG: PDF retrieval-augmented chatbot (OpenAI embeddings/chat, Annoy + BM25 retrieval, Gradio UI).
| import time | |
| import fitz | |
| import numpy as np | |
| import dill | |
| import os | |
| import logging | |
| import asyncio | |
| import networkx as nx | |
| from annoy import AnnoyIndex | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from rank_bm25 import BM25Okapi | |
| from gensim.models import Word2Vec | |
| from typing import List, Optional, Tuple | |
| import gradio as gr | |
| from openai import OpenAI | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# OpenAI client; requires OPENAI_API_KEY to be set in the environment.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Input document and on-disk artifacts for the persisted vector store.
PDF_PATH = "input.pdf"
VECTOR_DB_PATH = "vector_db.pkl"
ANNOY_INDEX_PATH = "vector_index.ann"
def get_text_embedding_with_rate_limit(text_list, initial_delay=2, max_retries=10, max_delay=60):
    """Embed each text via the OpenAI API with retries and exponential backoff.

    Args:
        text_list: Iterable of text chunks to embed.
        initial_delay: Seconds slept after each call; doubled after each failure.
        max_retries: Attempts per chunk before giving up on it.
        max_delay: Ceiling for the backoff delay, in seconds.

    Returns:
        One embedding (list of floats) per input text, in the same order.
        A chunk that still fails after max_retries contributes a zero vector
        so the output stays aligned 1:1 with text_list. (Previously failed
        chunks were silently dropped, desynchronizing the texts and
        embeddings stored by store_embeddings_in_vector_db.)
    """
    embeddings = []
    for text in text_list:
        delay = initial_delay
        result = None
        for attempt in range(1, max_retries + 1):
            try:
                # NOTE(review): this truncates by characters, not tokens —
                # 8192 chars is not the model's token limit; confirm intent.
                if len(text) > 8192:
                    logger.warning("Text chunk exceeds the token limit. Truncating the text.")
                    text = text[:8192]
                response = client.embeddings.create(
                    model="text-embedding-3-small",
                    input=[text],
                )
                result = response.data[0].embedding
                time.sleep(delay)  # simple pacing between successful calls
                break
            except Exception as e:
                logger.warning(f"Embedding retry {attempt}/{max_retries} after error: {e}")
                time.sleep(delay)
                delay = min(delay * 2, max_delay)
        if result is None:
            logger.error("Max retries reached. Substituting a zero vector for this chunk.")
            result = [0.0] * 1536  # output dim of text-embedding-3-small
        embeddings.append(result)
    return embeddings
def split_text_into_chunks(text: str, chunk_size: int = 2048, overlap: int = 200) -> List[str]:
    """Split whitespace-tokenized text into overlapping chunks.

    Args:
        text: Source text; tokenized by str.split() (any whitespace).
        chunk_size: Maximum number of tokens per chunk.
        overlap: Number of tokens shared between consecutive chunks.

    Returns:
        List of chunk strings; empty list for empty/whitespace-only input.

    Raises:
        ValueError: If overlap >= chunk_size (the original loop would never
            advance and spin forever).
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    tokens = text.split()
    step = chunk_size - overlap
    return [" ".join(tokens[start:start + chunk_size]) for start in range(0, len(tokens), step)]
def store_embeddings_in_vector_db(pdf_path, vector_db_path, annoy_index_path, chunk_size=2048, overlap=200, num_trees=10):
    """Extract text from a PDF, embed it, and persist the vectors and index.

    Writes two artifacts: a dill pickle at vector_db_path holding
    {'embeddings': float32 ndarray, 'texts': list[str]}, and an Annoy
    angular-distance index at annoy_index_path.

    Args:
        pdf_path: Path of the PDF to ingest.
        vector_db_path: Destination for the pickled embeddings/texts.
        annoy_index_path: Destination for the Annoy index file.
        chunk_size: Tokens per chunk (see split_text_into_chunks).
        overlap: Token overlap between consecutive chunks.
        num_trees: Annoy tree count (accuracy/speed trade-off).

    Raises:
        ValueError: If the PDF yields no extractable text. (Previously this
            crashed later with an opaque IndexError on `shape[1]`.)
    """
    all_embeddings = []
    all_texts = []
    doc = fitz.open(pdf_path)
    try:
        for page_num in range(doc.page_count):
            text = doc.load_page(page_num).get_text()
            if not text.strip():
                continue
            chunks = split_text_into_chunks(text, chunk_size, overlap)
            embeddings = get_text_embedding_with_rate_limit(chunks)
            all_embeddings.extend(embeddings)
            all_texts.extend(chunks)
    finally:
        doc.close()  # the document handle was previously never closed
    if not all_embeddings:
        raise ValueError(f"No extractable text found in {pdf_path!r}")
    embeddings_np = np.array(all_embeddings).astype('float32')
    with open(vector_db_path, "wb") as f:
        dill.dump({'embeddings': embeddings_np, 'texts': all_texts}, f)
    # Annoy cannot rebuild in place; remove any stale index first.
    if os.path.exists(annoy_index_path):
        os.remove(annoy_index_path)
    embedding_dim = embeddings_np.shape[1]
    annoy_index = AnnoyIndex(embedding_dim, 'angular')
    for i, embedding in enumerate(embeddings_np):
        annoy_index.add_item(i, embedding)
    annoy_index.build(num_trees)
    annoy_index.save(annoy_index_path)
# Build the vector store on first run; later runs reuse the persisted files.
if not os.path.exists(VECTOR_DB_PATH) or not os.path.exists(ANNOY_INDEX_PATH):
    store_embeddings_in_vector_db(PDF_PATH, VECTOR_DB_PATH, ANNOY_INDEX_PATH)
class MistralRAGChatbot:
    """Hybrid RAG chatbot over a pre-built vector store.

    Retrieval combines Annoy approximate nearest-neighbour search with BM25
    lexical scores; candidates are reranked with a PageRank score computed
    once over a chunk-similarity graph, and answers come from the OpenAI
    chat API (streamed).
    """

    # Cosine-similarity threshold above which two chunks get a graph edge.
    _SIM_EDGE_THRESHOLD = 0.5

    def __init__(self, vector_db_path: str, annoy_index_path: str):
        """Load the persisted embeddings/texts and build retrieval structures."""
        with open(vector_db_path, "rb") as f:
            data = dill.load(f)
        self.embeddings = np.array(data['embeddings'], dtype='float32')
        self.texts = data['texts']
        self.annoy_index = AnnoyIndex(self.embeddings.shape[1], 'angular')
        self.annoy_index.load(annoy_index_path)
        tokenized = [text.split() for text in self.texts]
        self.bm25 = BM25Okapi(tokenized)
        # NOTE(review): the Word2Vec model is trained but never used by any
        # method here — confirm whether it can be dropped.
        self.word2vec_model = Word2Vec(tokenized, vector_size=100, window=5, min_count=1, workers=4)
        # The corpus is static, so compute PageRank once here instead of
        # rebuilding the O(n^2) similarity graph on every query.
        self.pagerank_scores = self._compute_pagerank()

    def _compute_pagerank(self) -> dict:
        """Build the chunk-similarity graph; return {chunk_index: pagerank}."""
        sim_matrix = cosine_similarity(self.embeddings)
        graph = nx.Graph()
        n = len(self.embeddings)
        graph.add_nodes_from(range(n))  # keep isolated chunks in the ranking
        for i in range(n):
            for j in range(i + 1, n):
                if sim_matrix[i, j] > self._SIM_EDGE_THRESHOLD:
                    graph.add_edge(i, j, weight=sim_matrix[i, j])
        return nx.pagerank(graph, weight='weight')

    def get_text_embedding(self, text: str) -> np.ndarray:
        """Embed `text`; on API failure, log and return a 1536-dim zero vector."""
        try:
            response = client.embeddings.create(
                model="text-embedding-3-small",
                input=[text],
            )
            return np.array(response.data[0].embedding, dtype=np.float32)
        except Exception as e:
            logger.error(f"Error fetching embedding: {e}")
            return np.zeros((1536,), dtype=np.float32)

    def retrieve_documents(self, query: str, embedding: np.ndarray, top_k=10) -> List[dict]:
        """Return up to top_k candidate chunks scored by BM25 + vector similarity.

        Each dict has keys: 'text', 'method', 'score', 'index'.
        """
        indices, distances = self.annoy_index.get_nns_by_vector(embedding, top_k, include_distances=True)
        bm25_scores = self.bm25.get_scores(query.split())
        combined_docs = []
        for idx, dist in zip(indices, distances):
            # Annoy 'angular' distance d relates to cosine as cos = 1 - d^2/2.
            # Fold it into the score so 'hybrid' is accurate — previously the
            # distances were computed and then discarded.
            cosine_sim = 1.0 - (dist ** 2) / 2.0
            combined_docs.append({
                'text': self.texts[idx],
                'method': 'hybrid',
                'score': float(bm25_scores[idx]) + cosine_sim,
                'index': idx,
            })
        return combined_docs

    def rerank_documents(self, query: str, docs: List[dict]) -> List[dict]:
        """Blend retrieval score with PageRank (0.7/0.3); return the top 5.

        Bug fix: the old code flattened the pagerank dict's *values* into an
        array and indexed it by chunk index — dict value order does not map
        to chunk indices, and chunks absent from the graph could raise
        IndexError. Look scores up by key instead. It also fetched an unused
        query embedding, wasting an API call per query.
        """
        for doc in docs:
            idx = doc['index']
            doc['score'] = 0.7 * doc['score'] + 0.3 * self.pagerank_scores.get(idx, 0.0)
        return sorted(docs, key=lambda d: d['score'], reverse=True)[:5]

    def build_prompt(self, context: str, query: str, style: str) -> str:
        """Assemble the chat prompt; unknown styles fall back to 'detailed'."""
        styles = {
            "detailed": "Provide a detailed answer.",
            "concise": "Provide a concise answer.",
            "creative": "Be creative in your response.",
            "technical": "Provide a technically sound answer.",
        }
        instruction = styles.get(style.lower(), styles["detailed"])
        return f"""You are a helpful assistant.\nContext:\n{context}\nQuestion:\n{query}\nInstruction:\n{instruction}"""

    def generate_response(self, query: str, style: str) -> str:
        """Retrieve + rerank context for `query`, then stream a chat answer."""
        query_embedding = self.get_text_embedding(query)
        docs = self.retrieve_documents(query, query_embedding)
        reranked_docs = self.rerank_documents(query, docs)
        context = "\n\n".join(doc['text'] for doc in reranked_docs)
        prompt = self.build_prompt(context, query, style)
        try:
            stream = client.chat.completions.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt}],
                stream=True,
            )
            parts = []
            for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    parts.append(chunk.choices[0].delta.content)
            return "".join(parts)
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return "Sorry, I couldn't generate a response."
def chatbot_interface(user_query, response_style):
    """Gradio callback: answer `user_query` in the requested response style.

    The chatbot is built lazily on first use and cached on the function
    object — the previous version constructed a new MistralRAGChatbot
    (reloading the vector DB and retraining Word2Vec) on every single query.
    """
    bot = getattr(chatbot_interface, "_bot", None)
    if bot is None:
        bot = MistralRAGChatbot(VECTOR_DB_PATH, ANNOY_INDEX_PATH)
        chatbot_interface._bot = bot
    return bot.generate_response(user_query, response_style)
# Gradio UI: a Blocks page hosting a logo image plus the Q&A interface.
# NOTE(review): nesting a gr.Interface inside gr.Blocks is unusual but
# supported; confirm this produces the intended layout.
iface = gr.Blocks(theme="Rabbitt-AI/ChanceRAG")
with iface:
    gr.Image("images/ChatHapi_logo.png", label="Image", show_label=False)
    gr.Interface(
        fn=chatbot_interface,
        theme="Rabbitt-AI/ChanceRAG",
        inputs=[
            gr.Textbox(lines=5, label="User Query"),
            gr.Dropdown(["Detailed", "Concise", "Creative", "Technical"], label="Response Style"),
        ],
        outputs=gr.Textbox(label="ChatHapi Response"),
    )
# share=True exposes a public tunnel URL.
iface.launch(share=True)