# NOTE(review): the lines previously here were Hugging Face Space page residue
# captured by the scraper — "Spaces: Sleeping" status, file size, commit
# hashes, and a run of display line numbers. They were not Python source and
# made the module unparseable, so they are reduced to this comment.
import os
import httpx
import gradio as gr
from openai import OpenAI
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from fastembed import SparseTextEmbedding
# DeepSeek OpenAI-compatible endpoint; the key is read from the environment
# (Space secrets) and may be None — predict() checks before calling the API.
API_KEY = os.environ.get('DEEPSEEK_API_KEY')
BASE_URL = "https://api.deepseek.com"
# Embedded (on-disk) Qdrant database and the collection it must already
# contain — HFRAG.__init__ raises if either is missing.
QDRANT_PATH = "./qdrant_db"
COLLECTION_NAME = "huggingface_transformers_docs"
# Dense embedding model (fine-tuned Jina) and sparse model (SPLADE) used for
# the hybrid retrieval in HFRAG.retrieve.
EMBEDDING_MODEL_ID = "fyerfyer/finetune-jina-transformers-v1"
SPARSE_MODEL_ID = "prithivida/Splade_PP_en_v1"
class HFRAG:
    """Hybrid (dense + sparse) RAG retriever over a local embedded Qdrant DB."""

    def __init__(self):
        # Dense embeddings from the fine-tuned Jina model, sparse from SPLADE.
        self.dense_model = SentenceTransformer(EMBEDDING_MODEL_ID, trust_remote_code=True)
        self.sparse_model = SparseTextEmbedding(model_name=SPARSE_MODEL_ID)
        # A stale .lock left behind by a crashed process prevents the embedded
        # Qdrant client from opening the DB; remove it best-effort.
        lock_file = os.path.join(QDRANT_PATH, ".lock")
        if os.path.exists(lock_file):
            try:
                os.remove(lock_file)
                print("Cleaned up stale lock file.")
            except OSError:
                # Best-effort only: if removal fails, let QdrantClient raise
                # the real error below instead of masking it here.
                pass
        if not os.path.exists(QDRANT_PATH):
            raise ValueError(f"Qdrant path not found: {QDRANT_PATH}.")
        self.db_client = QdrantClient(path=QDRANT_PATH)
        if not self.db_client.collection_exists(COLLECTION_NAME):
            raise ValueError(f"Collection '{COLLECTION_NAME}' not found in Qdrant DB.")
        print("Connected to Qdrant")
        # trust_env=False / proxy=None keep Space proxy env vars from
        # interfering with the DeepSeek API call.
        self.llm_client = OpenAI(
            api_key=API_KEY,
            base_url=BASE_URL,
            http_client=httpx.Client(proxy=None, trust_env=False)
        )

    def retrieve(self, query: str, top_k: int = 5):
        """Hybrid search: dense + sparse prefetch fused with RRF.

        Returns the top_k scored points (with payload) from Qdrant.
        """
        # Dense query vector.
        query_dense_vec = self.dense_model.encode(query).tolist()
        # Sparse query vector (fastembed yields a generator; single query).
        query_sparse_gen = list(self.sparse_model.embed([query]))[0]
        query_sparse_vec = models.SparseVector(
            indices=query_sparse_gen.indices.tolist(),
            values=query_sparse_gen.values.tolist()
        )
        # Each branch over-fetches 20 candidates; RRF fusion re-ranks them.
        prefetch_dense = models.Prefetch(
            query=query_dense_vec,
            using="text-dense",
            limit=20,
        )
        prefetch_sparse = models.Prefetch(
            query=query_sparse_vec,
            using="text-sparse",
            limit=20,
        )
        results = self.db_client.query_points(
            collection_name=COLLECTION_NAME,
            prefetch=[prefetch_dense, prefetch_sparse],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=top_k,
            with_payload=True
        ).points
        return results

    def format_context(self, search_results):
        """Build the LLM context block and a per-hit source summary.

        Returns (context, sources): context joins one <doc> element per hit,
        sources is a list of "`filename` (Score: x.xx)" markdown strings.
        """
        context_pieces = []
        sources_summary = []
        for idx, hit in enumerate(search_results, 1):
            raw_source = hit.payload.get('source', 'unknown')
            # Basename only, so citations match the prompt's `[file.md]` style.
            filename = raw_source.split('/')[-1] if '/' in raw_source else raw_source
            text = hit.payload['text']
            score = hit.score
            # BUG FIX: the literal placeholder "(unknown)" was emitted here
            # instead of the computed filename, making every source label and
            # <doc> citation useless.
            sources_summary.append(f"`{filename}` (Score: {score:.2f})")
            piece = f'<doc id="{idx}" source="{filename}">\n{text}\n</doc>'
            context_pieces.append(piece)
        return "\n\n".join(context_pieces), sources_summary
# Module-level singleton; populated lazily on the first chat request so a
# broken environment fails per-request instead of at import time.
rag_system = None


def initialize_system():
    """Return the shared HFRAG instance, building it on first use.

    Returns None (after logging) when construction fails.
    """
    global rag_system
    if rag_system is not None:
        return rag_system
    try:
        rag_system = HFRAG()
    except Exception as e:
        print(f"Error initializing: {e}")
        return None
    return rag_system
# ================= Gradio Logic =================
def predict(message, history):
    """Gradio streaming handler: retrieve context, then stream the LLM answer.

    Yields progressively longer markdown strings; gr.ChatInterface renders the
    most recent yield. `history` is supplied by Gradio but unused — every turn
    is answered from retrieval alone.
    """
    rag = initialize_system()
    if not rag:
        yield "❌ System initialization failed. Check logs."
        return
    if not API_KEY:
        yield "❌ Error: `DEEPSEEK_API_KEY` not set in Space secrets."
        return
    # 1. Retrieve
    yield "🔍 Retrieving relevant documents..."
    try:
        results = rag.retrieve(message)
    except Exception as e:
        # Surface Qdrant/embedding failures in the chat instead of crashing
        # the Gradio worker mid-stream.
        yield f"❌ Retrieval error: {str(e)}"
        return
    if not results:
        yield "⚠️ No relevant documents found in the knowledge base."
        return
    # 2. Format context
    context_str, sources_list = rag.format_context(results)
    # 3. Build Prompt
    system_prompt = """You are an expert AI assistant specializing in the Hugging Face Transformers library.
Your goal is to answer the user's question based ONLY on the provided "Retrieved Context".
GUIDELINES:
1. **Code First**: Prioritize showing Python code examples.
2. **Citation**: Cite source filenames like `[model_doc.md]`.
3. **Honesty**: If the answer isn't in the context, say you don't know.
4. **Format**: Use Markdown."""
    user_prompt = f"""### User Query\n{message}\n\n### Retrieved Context\n{context_str}"""
    # Show the retrieved sources first; streamed tokens are appended below.
    header = "**📚 Found relevant documents:**\n" + "\n".join([f"- {s}" for s in sources_list]) + "\n\n---\n\n"
    current_response = header
    yield current_response
    try:
        response = rag.llm_client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.1,
            stream=True
        )
        for chunk in response:
            if chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content
                current_response += content
                yield current_response
    except Exception as e:
        # Keep whatever streamed so far and append the error message.
        yield current_response + f"\n\n❌ LLM API Error: {str(e)}"
# Wire the streaming handler into a chat UI. (Fixed: mojibake in the title
# emoji and a stray trailing "|" scraping artifact after demo.launch().)
demo = gr.ChatInterface(
    fn=predict,
    title="🤖 Hugging Face RAG Expert",
    description="Ask me anything about Transformers! Powered by DeepSeek-V3 & Finetuned Embeddings.",
    examples=[
        "How to implement padding?",
        "How to use BERT pipeline?",
        "How to fine-tune a model using Trainer?",
        "What is the difference between padding and truncation?"
    ],
    theme="soft"
)

if __name__ == "__main__":
    demo.launch()