import os
import httpx
import gradio as gr
from openai import OpenAI
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from fastembed import SparseTextEmbedding

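# DeepSeek serves an OpenAI-compatible API, so the standard OpenAI client
# works here once pointed at the DeepSeek base URL.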
API_KEY = os.environ.get('DEEPSEEK_API_KEY')
BASE_URL = "https://api.deepseek.com"

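# Local (embedded) Qdrant store plus the dense/sparse models used to build it;
# the query-time models must match the ones used at indexing time.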
QDRANT_PATH = "./qdrant_db"
COLLECTION_NAME = "huggingface_transformers_docs"
EMBEDDING_MODEL_ID = "fyerfyer/finetune-jina-transformers-v1"
SPARSE_MODEL_ID = "prithivida/Splade_PP_en_v1"

class HFRAG:
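  """RAG pipeline over the Transformers docs: hybrid dense+sparse retrieval from Qdrant, answered by DeepSeek."""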
  def __init__(self):
    self.dense_model = SentenceTransformer(EMBEDDING_MODEL_ID, trust_remote_code=True)
    self.sparse_model = SparseTextEmbedding(model_name=SPARSE_MODEL_ID)
    
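    # Embedded Qdrant guards its storage with a .lock file; one left behind by
    # a crashed process would otherwise block reopening the database.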
    lock_file = os.path.join(QDRANT_PATH, ".lock")
    if os.path.exists(lock_file):
      try:
        os.remove(lock_file)
        print("Cleaned up stale lock file.")
      except OSError:
        pass  # a live process may still hold the lock; QdrantClient will raise if so
    
    if not os.path.exists(QDRANT_PATH):
      raise ValueError(f"Qdrant path not found: {QDRANT_PATH}.")

    self.db_client = QdrantClient(path=QDRANT_PATH)
    
    if not self.db_client.collection_exists(COLLECTION_NAME):
      raise ValueError(f"Collection '{COLLECTION_NAME}' not found in Qdrant DB.")
      
    print(f"Connected to Qdrant")

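    # trust_env=False makes httpx ignore proxy environment variables, which can
    # otherwise break outbound API calls on hosted Spaces.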
    self.llm_client = OpenAI(
      api_key=API_KEY,
      base_url=BASE_URL,
      http_client=httpx.Client(proxy=None, trust_env=False)
    )

  def retrieve(self, query: str, top_k: int = 5):
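    """Hybrid search: prefetch dense and sparse candidates, then fuse the two ranked lists with Reciprocal Rank Fusion (RRF)."""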
    # Generate dense vector
    query_dense_vec = self.dense_model.encode(query).tolist()
    
    # Generate sparse vector
    query_sparse_gen = list(self.sparse_model.embed([query]))[0]
    query_sparse_vec = models.SparseVector(
      indices=query_sparse_gen.indices.tolist(),
      values=query_sparse_gen.values.tolist()
    )
    
    # Create prefetch for dense retrieval
    prefetch_dense = models.Prefetch(
      query=query_dense_vec,
      using="text-dense",
      limit=20,
    )
    
    # Create prefetch for sparse retrieval
    prefetch_sparse = models.Prefetch(
      query=query_sparse_vec,
      using="text-sparse",
      limit=20,
    )
    
    # Hybrid search with RRF fusion
    results = self.db_client.query_points(
      collection_name=COLLECTION_NAME,
      prefetch=[prefetch_dense, prefetch_sparse],
      query=models.FusionQuery(fusion=models.Fusion.RRF),
      limit=top_k,
      with_payload=True
    ).points

    return results

  def format_context(self, search_results):
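    """Wrap each hit in a <doc> tag for the LLM prompt and collect a human-readable source list for the UI."""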
    context_pieces = []
    sources_summary = []
    
    for idx, hit in enumerate(search_results, 1):
      raw_source = hit.payload.get('source', 'unknown')
      filename = raw_source.split('/')[-1]  # basename; also correct when there is no '/'
      text = hit.payload.get('text', '')
      score = hit.score
      
      sources_summary.append(f"`{filename}` (Score: {score:.2f})")

      piece = f"""<doc id="{idx}" source="{filename}">\n{text}\n</doc>"""
      context_pieces.append(piece)
      
    return "\n\n".join(context_pieces), sources_summary

rag_system = None

def initialize_system():
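  """Lazily build the HFRAG singleton so heavy model loading happens on the first request, not at import time."""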
  global rag_system
  if rag_system is None:
    try:
      rag_system = HFRAG()
    except Exception as e:
      print(f"Error initializing: {e}")
      return None
  return rag_system

# ================= Gradio Logic =================
def predict(message, history):
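  """Streaming callback for gr.ChatInterface; `history` is part of the required signature but unused here."""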
  rag = initialize_system()
  
  if not rag:
    yield "❌ System initialization failed. Check logs."
    return
  
  if not API_KEY:
    yield "❌ Error: `DEEPSEEK_API_KEY` not set in Space secrets."
    return

  # 1. Retrieve
  yield "πŸ” Retrieving relevant documents..."
  results = rag.retrieve(message)
  
  if not results:
    yield "⚠️ No relevant documents found in the knowledge base."
    return

  # 2. Format context
  context_str, sources_list = rag.format_context(results)
  
  # 3. Build Prompt
  system_prompt = """You are an expert AI assistant specializing in the Hugging Face Transformers library.
Your goal is to answer the user's question based ONLY on the provided "Retrieved Context".

GUIDELINES:
1. **Code First**: Prioritize showing Python code examples.
2. **Citation**: Cite source filenames like `[model_doc.md]`.
3. **Honesty**: If the answer isn't in the context, say you don't know.
4. **Format**: Use Markdown."""

  user_prompt = f"""### User Query\n{message}\n\n### Retrieved Context\n{context_str}"""

  header = "**πŸ“š Found relevant documents:**\n" + "\n".join([f"- {s}" for s in sources_list]) + "\n\n---\n\n"
  current_response = header
  yield current_response

  try:
    response = rag.llm_client.chat.completions.create(
      model="deepseek-chat",
      messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
      ],
      temperature=0.1,
      stream=True
    )
    
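    # Re-yield the accumulated string on every chunk; ChatInterface replaces the
    # displayed message with each yield, producing the streaming effect.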
    for chunk in response:
      if chunk.choices and chunk.choices[0].delta.content:
        content = chunk.choices[0].delta.content
        current_response += content
        yield current_response
        
  except Exception as e:
    yield current_response + f"\n\n❌ LLM API Error: {str(e)}"

demo = gr.ChatInterface(
  fn=predict,
  title="πŸ€— Hugging Face RAG Expert",
  description="Ask me anything about Transformers! Powered by DeepSeek-V3 & Finetuned Embeddings.",
  examples=[
    "How to implement padding?",
    "How to use BERT pipeline?", 
    "How to fine-tune a model using Trainer?",
    "What is the difference between padding and truncation?"
  ],
  theme="soft"
)

if __name__ == "__main__":
  demo.launch()