VietCat commited on
Commit
7001450
·
1 Parent(s): 3c1e19a

init project

Browse files
Files changed (2) hide show
  1. app.py +35 -23
  2. rag_core/retriever.py +2 -0
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import logging
3
  from fastapi import FastAPI, Request
@@ -6,20 +7,14 @@ from rag_core.embedder import get_embedding
6
  from rag_core.retriever import Retriever
7
  from rag_core.llm import generate_answer
8
 
9
- app = FastAPI()
10
  retriever = Retriever()
 
 
11
 
12
- # Khởi tạo nếu chưa có index
13
- if retriever.index is None:
14
- logging.info("Không tìm thấy FAISS index, bắt đầu xử lý...")
15
- with open("data/raw_law.txt", "r", encoding="utf-8") as f:
16
- text = f.read()
17
- chunks = chunk_legal_text(text)
18
- retriever.build(chunks, get_embedding)
19
-
20
- # API endpoint
21
  @app.post("/ask")
22
  async def ask_api(req: Request):
 
 
23
  data = await req.json()
24
  query = data.get("query")
25
  docs = retriever.query(query, get_embedding)
@@ -29,25 +24,42 @@ async def ask_api(req: Request):
29
 
30
  @app.post("/rescan")
31
  async def rescan_api():
 
 
32
  with open("data/raw_law.txt", "r", encoding="utf-8") as f:
33
  text = f.read()
34
  chunks = chunk_legal_text(text)
35
  retriever.rescan_and_append(chunks, get_embedding)
36
  return {"status": "Rescan & update thành công."}
37
 
38
- # Gradio UI
39
- iface = gr.Interface(
40
- fn=lambda q: generate_answer("\n\n".join(retriever.query(q, get_embedding)) + f"\n\nCâu hỏi: {q}\nTrả lời:"),
41
- inputs=gr.Textbox(label="Nhập câu hỏi"),
42
- outputs=gr.Textbox(label="Trả lời"),
43
- title="Luật Giao Thông RAG"
44
- )
 
 
 
 
 
 
 
 
 
 
45
 
46
- import uvicorn
47
- import threading
 
 
 
 
48
 
49
- def start_fastapi():
50
- uvicorn.run(app, host="0.0.0.0", port=7861)
51
 
52
- threading.Thread(target=start_fastapi).start()
53
- iface.launch()
 
 
1
+ import os
2
  import gradio as gr
3
  import logging
4
  from fastapi import FastAPI, Request
 
7
  from rag_core.retriever import Retriever
8
  from rag_core.llm import generate_answer
9
 
 
10
  retriever = Retriever()
11
+ app = FastAPI()
12
+ ready = retriever.index is not None
13
 
 
 
 
 
 
 
 
 
 
14
  @app.post("/ask")
15
  async def ask_api(req: Request):
16
+ if not ready:
17
+ return {"error": "Index chưa sẵn sàng. Vui lòng thử lại sau."}
18
  data = await req.json()
19
  query = data.get("query")
20
  docs = retriever.query(query, get_embedding)
 
24
 
25
  @app.post("/rescan")
26
  async def rescan_api():
27
+ if not ready:
28
+ return {"error": "Index chưa sẵn sàng."}
29
  with open("data/raw_law.txt", "r", encoding="utf-8") as f:
30
  text = f.read()
31
  chunks = chunk_legal_text(text)
32
  retriever.rescan_and_append(chunks, get_embedding)
33
  return {"status": "Rescan & update thành công."}
34
 
35
+ def build_index_ui():
36
+ global ready
37
+ with gr.Textbox(visible=False):
38
+ pass # trigger
39
+ with open("data/raw_law.txt", "r", encoding="utf-8") as f:
40
+ text = f.read()
41
+ chunks = chunk_legal_text(text)
42
+ retriever.build(chunks, get_embedding)
43
+ ready = True
44
+ return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
45
+
46
+ def answer_fn(query):
47
+ if not ready:
48
+ return "Index chưa sẵn sàng. Vui lòng chờ hoàn tất xử lý."
49
+ docs = retriever.query(query, get_embedding)
50
+ prompt = "\n\n".join(docs) + f"\n\nCâu hỏi: {query}\nTrả lời:"
51
+ return generate_answer(prompt)
52
 
53
+ with gr.Blocks() as iface:
54
+ build_btn = gr.Button("🔄 Xây Index", visible=not ready)
55
+ query_box = gr.Textbox(label="Nhập câu hỏi", visible=ready)
56
+ output_box = gr.Textbox(label="Trả lời", visible=ready)
57
+ query_box.submit(fn=answer_fn, inputs=query_box, outputs=output_box)
58
+ build_btn.click(fn=build_index_ui, outputs=[build_btn, query_box, output_box])
59
 
60
+ # mount Gradio lên FastAPI
61
+ app = gr.mount_gradio_app(app, iface, path="/")
62
 
63
+ if __name__ == "__main__":
64
+ import uvicorn
65
+ uvicorn.run(app, host="0.0.0.0", port=7860)
rag_core/retriever.py CHANGED
@@ -20,6 +20,8 @@ class Retriever:
20
 
21
  @log_timed("xây FAISS index")
22
  def build(self, texts: list, embed_fn):
 
 
23
  embeddings = []
24
  valid_texts = []
25
  for i, t in enumerate(texts):
 
20
 
21
  @log_timed("xây FAISS index")
22
  def build(self, texts: list, embed_fn):
23
+ os.makedirs(os.path.dirname(INDEX_PATH), exist_ok=True)
24
+
25
  embeddings = []
26
  valid_texts = []
27
  for i, t in enumerate(texts):