Bob-Potato committed on
Commit
a2253c7
·
verified ·
1 Parent(s): 2ffe70e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -34
app.py CHANGED
@@ -3,13 +3,10 @@ import json
3
  import faiss
4
  import pickle
5
  import numpy as np
6
- from fastapi import FastAPI, HTTPException
7
- from pydantic import BaseModel
8
  from sentence_transformers import SentenceTransformer
9
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
10
 
11
- app = FastAPI(title="MetaGPT AI - Local Q&A")
12
-
13
  # Config
14
  DATA_DIR = "data"
15
  INDEX_FILE = os.path.join(DATA_DIR, "index.faiss")
@@ -20,7 +17,7 @@ JSON_FILE = "articles.json"
20
 
21
  os.makedirs(DATA_DIR, exist_ok=True)
22
 
23
- # Load models
24
  embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
25
  gen_model_name = "google/flan-t5-small"
26
  tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
@@ -52,12 +49,10 @@ def load_index():
52
  metadata = pickle.load(f)
53
  return index, metadata
54
 
55
- # ---- Build / Rebuild index from JSON ----
56
- @app.post("/build_index")
57
  def build_index():
58
  if not os.path.exists(JSON_FILE):
59
- raise HTTPException(status_code=404, detail=f"{JSON_FILE} not found")
60
-
61
  with open(JSON_FILE, "r", encoding="utf-8") as f:
62
  articles = json.load(f)
63
 
@@ -79,37 +74,34 @@ def build_index():
79
  metadata = {"texts": texts, "metas": metas}
80
 
81
  save_index(index, metadata)
82
- return {"status": "ok", "total_chunks": len(texts)}
83
-
84
- # ---- Ask endpoint ----
85
- class AskRequest(BaseModel):
86
- question: str
87
- top_k: int = 4
88
- max_answer_tokens: int = 256
89
 
90
- @app.post("/ask")
91
- def ask(req: AskRequest):
92
  index, metadata = load_index()
93
  if index is None:
94
- raise HTTPException(status_code=404, detail="No index found. Call /build_index first.")
95
-
96
- q_emb = embed_model.encode([req.question], convert_to_numpy=True).astype("float32")
 
 
97
  faiss.normalize_L2(q_emb)
98
- D, I = index.search(q_emb, req.top_k)
99
 
100
  retrieved = [metadata["texts"][i] for i in I[0]]
101
  urls = [metadata["metas"][i]["url"] for i in I[0] if "url" in metadata["metas"][i]]
102
 
103
  context = "\n\n".join(retrieved)
104
- prompt = f"Context:\n{context}\n\nQuestion: {req.question}\nAnswer:"
105
- out = gen_pipeline(prompt, max_length=req.max_answer_tokens, do_sample=False)[0]["generated_text"]
106
-
107
- return {
108
- "answer": f"{out} Find out more at {', '.join(urls)}",
109
- "sources": [metadata["metas"][i] for i in I[0]]
110
- }
111
-
112
- # ---- Health check ----
113
- @app.get("/health")
114
- def health():
115
- return {"status": "ok"}
 
 
 
3
  import faiss
4
  import pickle
5
  import numpy as np
6
+ import gradio as gr
 
7
  from sentence_transformers import SentenceTransformer
8
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
9
 
 
 
10
  # Config
11
  DATA_DIR = "data"
12
  INDEX_FILE = os.path.join(DATA_DIR, "index.faiss")
 
17
 
18
  os.makedirs(DATA_DIR, exist_ok=True)
19
 
20
+ # Models
21
  embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
22
  gen_model_name = "google/flan-t5-small"
23
  tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
 
49
  metadata = pickle.load(f)
50
  return index, metadata
51
 
 
 
52
  def build_index():
53
  if not os.path.exists(JSON_FILE):
54
+ return None, None
55
+
56
  with open(JSON_FILE, "r", encoding="utf-8") as f:
57
  articles = json.load(f)
58
 
 
74
  metadata = {"texts": texts, "metas": metas}
75
 
76
  save_index(index, metadata)
77
+ return index, metadata
 
 
 
 
 
 
78
 
79
def ask_question(question, top_k=4, max_answer_tokens=256):
    """Answer *question* using the indexed article chunks as context.

    Loads the saved FAISS index (building it via ``build_index`` when none
    is available), retrieves the ``top_k`` nearest chunks for the question
    embedding, and feeds them plus the question to ``gen_pipeline``.

    Args:
        question: The user's question text.
        top_k: Number of nearest chunks requested from the index.
        max_answer_tokens: Passed to the generation pipeline as ``max_length``.

    Returns:
        The generated answer as a string, followed by the source URLs of the
        retrieved chunks when any are present, or an error message when no
        index exists and ``build_index`` could not create one.
    """
    index, metadata = load_index()
    if index is None:
        # No saved index yet: try to build one before giving up.
        index, metadata = build_index()
        if index is None:
            return "Error: articles.json not found."

    q_emb = embed_model.encode([question], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)
    _, neighbours = index.search(q_emb, top_k)

    # FAISS pads the result row with -1 when the index holds fewer than
    # top_k vectors; used as a Python index, -1 would silently select the
    # LAST chunk, so padded slots are dropped here.
    hits = [int(i) for i in neighbours[0] if i >= 0]

    retrieved = [metadata["texts"][i] for i in hits]

    # Several chunks can come from the same article; keep each URL once,
    # in retrieval order.
    urls = list(
        dict.fromkeys(
            metadata["metas"][i]["url"] for i in hits if "url" in metadata["metas"][i]
        )
    )

    context = "\n\n".join(retrieved)
    prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    out = gen_pipeline(prompt, max_length=max_answer_tokens, do_sample=False)[0]["generated_text"]

    # Without any source URL the suffix would end in a dangling "at ".
    if not urls:
        return out
    return f"{out} Find out more at {', '.join(urls)}"
98
+
99
# Gradio UI: one input textbox wired to ask_question and one output textbox
# for its answer. Only the question is exposed as an input, so top_k and
# max_answer_tokens keep their default values.
_question_box = gr.Textbox(label="Întrebare")
_answer_box = gr.Textbox(label="Răspuns")

iface = gr.Interface(
    fn=ask_question,
    inputs=[_question_box],
    outputs=[_answer_box],
    live=False,
)

# Serve on all network interfaces, port 7860.
iface.launch(server_name="0.0.0.0", server_port=7860)