Ariyan-Pro committed on
Commit
eed2a86
·
verified ·
1 Parent(s): a491868

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -32
app.py CHANGED
@@ -3,46 +3,81 @@ import time
3
  import sys
4
  import os
5
 
6
- # Add the repo root to path so we can import the /app modules
7
  sys.path.append(os.path.dirname(__file__))
8
 
9
- # Import your three RAG implementations
10
- from app.rag_naive import NaiveRAG
11
- from app.rag_optimized import OptimizedRAG
12
- from app.no_compromise_rag import NoCompromiseRAG
 
13
 
14
- # -------------------------------------------------------------------
15
- # Initialize the three RAG systems once at startup.
16
- # If memory becomes an issue, we can lazy‑load them inside each function.
17
- # -------------------------------------------------------------------
18
- print("Initializing Naive RAG...")
19
- naive_rag = NaiveRAG() # loads embedding model + FAISS index
20
- print("Initializing Optimized RAG...")
21
- optimized_rag = OptimizedRAG() # loads the same + SQLite cache
22
- print("Initializing No‑Compromise RAG...")
23
- no_compromise_rag = NoCompromiseRAG()
24
- print("All RAG systems ready.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # -------------------------------------------------------------------
27
- # Define the query functions for each mode
28
- # -------------------------------------------------------------------
29
  def query_naive(question):
30
- start = time.perf_counter()
31
- answer, chunks_used, cache_hit = naive_rag.query(question)
32
- latency = (time.perf_counter() - start) * 1000
33
- return answer, f"{latency:.1f} ms", chunks_used, "Yes" if cache_hit else "No"
 
 
 
 
34
 
35
  def query_optimized(question):
36
- start = time.perf_counter()
37
- answer, chunks_used, cache_hit = optimized_rag.query(question)
38
- latency = (time.perf_counter() - start) * 1000
39
- return answer, f"{latency:.1f} ms", chunks_used, "Yes" if cache_hit else "No"
 
 
 
 
40
 
41
  def query_no_compromise(question):
42
- start = time.perf_counter()
43
- answer, chunks_used, cache_hit = no_compromise_rag.query(question)
44
- latency = (time.perf_counter() - start) * 1000
45
- return answer, f"{latency:.1f} ms", chunks_used, "Yes" if cache_hit else "No"
 
 
 
 
46
 
47
  # -------------------------------------------------------------------
48
  # Build the Gradio interface
@@ -112,6 +147,5 @@ with gr.Blocks(title="RAG Latency Optimization", theme=gr.themes.Soft()) as demo
112
  **Caching**: SQLite (Optimized) + LRU memory | **Generation**: Simulated (real LLM can be plugged in)
113
  """)
114
 
115
- # Launch the app
116
  if __name__ == "__main__":
117
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
3
  import sys
4
  import os
5
 
6
+ # Add repo root to path
7
  sys.path.append(os.path.dirname(__file__))
8
 
9
+ # Global references to loaded systems
10
+ _naive_rag = None
11
+ _optimized_rag = None
12
+ _no_compromise_rag = None
13
+ _embedding_model = None # shared model
14
 
15
+ def get_embedding_model():
16
+ """Load the embedding model once and reuse it across all RAG classes."""
17
+ global _embedding_model
18
+ if _embedding_model is None:
19
+ from sentence_transformers import SentenceTransformer
20
+ print("Loading embedding model...")
21
+ _embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
22
+ return _embedding_model
23
+
24
+ def get_naive():
25
+ global _naive_rag
26
+ if _naive_rag is None:
27
+ from app.rag_naive import NaiveRAG
28
+ print("Initializing Naive RAG...")
29
+ # TODO: pass the shared model from get_embedding_model() here; it is not wired in yet
30
+ # (you may need to modify your RAG classes to accept a model argument)
31
+ _naive_rag = NaiveRAG()
32
+ # If NaiveRAG has a set_embedding_model method, call it:
33
+ # _naive_rag.set_embedding_model(get_embedding_model())
34
+ return _naive_rag
35
+
36
+ def get_optimized():
37
+ global _optimized_rag
38
+ if _optimized_rag is None:
39
+ from app.rag_optimized import OptimizedRAG
40
+ print("Initializing Optimized RAG...")
41
+ _optimized_rag = OptimizedRAG()
42
+ return _optimized_rag
43
+
44
+ def get_no_compromise():
45
+ global _no_compromise_rag
46
+ if _no_compromise_rag is None:
47
+ from app.no_compromise_rag import NoCompromiseRAG
48
+ print("Initializing No-Compromise RAG...")
49
+ _no_compromise_rag = NoCompromiseRAG()
50
+ return _no_compromise_rag
51
 
 
 
 
52
  def query_naive(question):
53
+ try:
54
+ rag = get_naive()
55
+ start = time.perf_counter()
56
+ answer, chunks_used, cache_hit = rag.query(question)
57
+ latency = (time.perf_counter() - start) * 1000
58
+ return answer, f"{latency:.1f} ms", str(chunks_used), "Yes" if cache_hit else "No"
59
+ except Exception as e:
60
+ return f"Error: {e}", "0 ms", "0", "No"
61
 
62
  def query_optimized(question):
63
+ try:
64
+ rag = get_optimized()
65
+ start = time.perf_counter()
66
+ answer, chunks_used, cache_hit = rag.query(question)
67
+ latency = (time.perf_counter() - start) * 1000
68
+ return answer, f"{latency:.1f} ms", str(chunks_used), "Yes" if cache_hit else "No"
69
+ except Exception as e:
70
+ return f"Error: {e}", "0 ms", "0", "No"
71
 
72
  def query_no_compromise(question):
73
+ try:
74
+ rag = get_no_compromise()
75
+ start = time.perf_counter()
76
+ answer, chunks_used, cache_hit = rag.query(question)
77
+ latency = (time.perf_counter() - start) * 1000
78
+ return answer, f"{latency:.1f} ms", str(chunks_used), "Yes" if cache_hit else "No"
79
+ except Exception as e:
80
+ return f"Error: {e}", "0 ms", "0", "No"
81
 
82
  # -------------------------------------------------------------------
83
  # Build the Gradio interface
 
147
  **Caching**: SQLite (Optimized) + LRU memory | **Generation**: Simulated (real LLM can be plugged in)
148
  """)
149
 
 
150
  if __name__ == "__main__":
151
  demo.launch(server_name="0.0.0.0", server_port=7860)