Bob-Potato committed on
Commit
2ffe70e
·
verified ·
1 Parent(s): 9c537e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -30
app.py CHANGED
@@ -1,8 +1,10 @@
1
- from fastapi import FastAPI, UploadFile, File, HTTPException
2
  import os
 
3
  import faiss
4
  import pickle
5
  import numpy as np
 
 
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
8
 
@@ -10,13 +12,13 @@ app = FastAPI(title="MetaGPT AI - Local Q&A")
10
 
11
  # Config
12
  DATA_DIR = "data"
13
- DOCS_DIR = os.path.join(DATA_DIR, "docs")
14
  INDEX_FILE = os.path.join(DATA_DIR, "index.faiss")
15
  METADATA_FILE = os.path.join(DATA_DIR, "metadata.pkl")
16
  CHUNK_SIZE = 500
17
  CHUNK_OVERLAP = 100
 
18
 
19
- os.makedirs(DOCS_DIR, exist_ok=True)
20
 
21
  # Load models
22
  embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
@@ -25,7 +27,7 @@ tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
25
  gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
26
  gen_pipeline = pipeline("text2text-generation", model=gen_model, tokenizer=tokenizer, device=-1)
27
 
28
- # Helper: chunk text
29
  def chunk_text(text):
30
  chunks = []
31
  start = 0
@@ -37,7 +39,6 @@ def chunk_text(text):
37
  if start >= len(text): break
38
  return chunks
39
 
40
- # Helper: save/load FAISS index
41
  def save_index(index, metadata):
42
  faiss.write_index(index, INDEX_FILE)
43
  with open(METADATA_FILE, "wb") as f:
@@ -51,58 +52,64 @@ def load_index():
51
  metadata = pickle.load(f)
52
  return index, metadata
53
 
54
- # Endpoint: upload document
@app.post("/upload")
async def upload(files: list[UploadFile] = File(...)):
    """Embed uploaded text files and append their chunks to the FAISS index.

    Every file is decoded as UTF-8, split with chunk_text, embedded with
    embed_model and L2-normalised before being added to the index. A new
    inner-product index is created when load_index() reports none exists.

    Args:
        files: Uploaded files; each must contain UTF-8 encoded text.

    Returns:
        A dict with "added_chunks" (chunks added by this call) and
        "total_chunks" (chunks stored after this call).

    Raises:
        HTTPException: 400 if a file is not valid UTF-8, or if none of the
            files produced any chunk to index.
    """
    index, metadata = load_index()
    embeddings, metas, texts = [], [], []

    for up in files:
        raw = await up.read()
        try:
            content = raw.decode("utf-8")
        except UnicodeDecodeError as err:
            # Report a client error instead of an unhandled 500.
            raise HTTPException(
                status_code=400,
                detail=f"{up.filename} is not valid UTF-8 text.",
            ) from err

        chunks = chunk_text(content)
        if not chunks:
            # Nothing to embed for this file; an empty batch would also
            # break the np.vstack call below.
            continue

        embeddings.append(embed_model.encode(chunks, convert_to_numpy=True))
        texts.extend(chunks)
        metas.extend({"source": up.filename, "chunk_id": i} for i in range(len(chunks)))

    if not embeddings:
        raise HTTPException(status_code=400, detail="No text found in the uploaded files.")

    embeddings = np.vstack(embeddings).astype("float32")
    # Normalise in place so inner product on the index compares unit vectors.
    faiss.normalize_L2(embeddings)

    if index is None:
        index = faiss.IndexFlatIP(embeddings.shape[1])
        metadata = {"texts": [], "metas": []}

    index.add(embeddings)
    metadata["texts"].extend(texts)
    metadata["metas"].extend(metas)

    save_index(index, metadata)
    return {"added_chunks": embeddings.shape[0], "total_chunks": len(metadata["texts"])}
82
 
83
- # Endpoint: ask question
84
- from pydantic import BaseModel
class AskRequest(BaseModel):
    """Request body accepted by the /ask endpoint."""

    # Question text; it is embedded for retrieval and inserted into the prompt.
    question: str
    # Number of nearest chunks requested from the index search.
    top_k: int = 4
    # Passed to the generation pipeline as its max_length.
    max_answer_tokens: int = 256
89
 
90
- from fastapi import Depends
@app.post("/ask")
def ask(req: AskRequest):
    """Answer a question using the chunks most similar to it in the index.

    The question is embedded and L2-normalised, the top_k nearest chunks are
    fetched from the FAISS index, and their text is passed as context to the
    generation pipeline.

    Args:
        req: Question, number of chunks to retrieve and generation length.

    Returns:
        A dict with "answer" (generated text) and "sources" (metadata of the
        chunks used as context).

    Raises:
        HTTPException: 404 if no index has been saved yet; 400 if top_k < 1.
    """
    index, metadata = load_index()
    if index is None:
        raise HTTPException(status_code=404, detail="No index found. Upload docs first.")
    if req.top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be at least 1.")

    q_emb = embed_model.encode([req.question], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)
    _, neighbours = index.search(q_emb, req.top_k)

    # FAISS fills missing results with -1 when the index holds fewer than
    # top_k vectors; as a Python index that would silently pick the last
    # chunk, so those slots are dropped.
    hits = [int(i) for i in neighbours[0] if i >= 0]

    retrieved = [metadata["texts"][i] for i in hits]
    context = "\n\n".join(retrieved)
    prompt = f"Context:\n{context}\n\nQuestion: {req.question}\nAnswer:"
    out = gen_pipeline(prompt, max_length=req.max_answer_tokens, do_sample=False)[0]["generated_text"]
    return {"answer": out, "sources": [metadata["metas"][i] for i in hits]}
104
 
105
- # Health check
 
 
 
 
 
@app.get("/health")
def health():
    """Health-check endpoint; always returns the same static payload."""
    payload = {"status": "ok"}
    return payload
 
 
1
  import os
2
+ import json
3
  import faiss
4
  import pickle
5
  import numpy as np
6
+ from fastapi import FastAPI, HTTPException
7
+ from pydantic import BaseModel
8
  from sentence_transformers import SentenceTransformer
9
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
10
 
 
12
 
13
  # Config
14
  DATA_DIR = "data"
 
15
  INDEX_FILE = os.path.join(DATA_DIR, "index.faiss")
16
  METADATA_FILE = os.path.join(DATA_DIR, "metadata.pkl")
17
  CHUNK_SIZE = 500
18
  CHUNK_OVERLAP = 100
19
+ JSON_FILE = "articles.json"
20
 
21
+ os.makedirs(DATA_DIR, exist_ok=True)
22
 
23
  # Load models
24
  embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
27
  gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
28
  gen_pipeline = pipeline("text2text-generation", model=gen_model, tokenizer=tokenizer, device=-1)
29
 
30
+ # Helpers
31
  def chunk_text(text):
32
  chunks = []
33
  start = 0
 
39
  if start >= len(text): break
40
  return chunks
41
 
 
42
  def save_index(index, metadata):
43
  faiss.write_index(index, INDEX_FILE)
44
  with open(METADATA_FILE, "wb") as f:
 
52
  metadata = pickle.load(f)
53
  return index, metadata
54
 
55
+ # ---- Build / Rebuild index from JSON ----
@app.post("/build_index")
def build_index():
    """Rebuild the FAISS index from scratch using the articles in JSON_FILE.

    Each article's "Continut" text is split with chunk_text, embedded with
    embed_model and L2-normalised before being added to a fresh
    inner-product index. The index and its metadata are then written with
    save_index, replacing whatever was stored before.

    Returns:
        {"status": "ok", "total_chunks": <number of chunks indexed>}

    Raises:
        HTTPException: 404 if JSON_FILE does not exist; 400 if no article
            produced any text to index.
    """
    if not os.path.exists(JSON_FILE):
        raise HTTPException(status_code=404, detail=f"{JSON_FILE} not found")

    with open(JSON_FILE, "r", encoding="utf-8") as f:
        # NOTE(review): assumes the file holds a list of dicts with
        # "Continut", "URL" and "Titlu" keys - confirm against the data.
        articles = json.load(f)

    embeddings, texts, metas = [], [], []

    for art_id, art in enumerate(articles):
        # `or ""` also covers an explicit null value in the JSON.
        content = art.get("Continut") or ""
        url = art.get("URL", "")
        chunks = chunk_text(content)
        if not chunks:
            # Articles without text contribute nothing; an empty batch
            # would also break the np.vstack call below.
            continue

        embeddings.append(embed_model.encode(chunks, convert_to_numpy=True))
        texts.extend(chunks)
        source = art.get("Titlu", f"articol_{art_id}")
        metas.extend(
            {"source": source, "url": url, "chunk_id": i} for i in range(len(chunks))
        )

    if not embeddings:
        raise HTTPException(
            status_code=400,
            detail=f"{JSON_FILE} contains no article text to index",
        )

    embeddings = np.vstack(embeddings).astype("float32")
    # Normalise in place so inner product on the index compares unit vectors.
    faiss.normalize_L2(embeddings)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    metadata = {"texts": texts, "metas": metas}

    save_index(index, metadata)
    return {"status": "ok", "total_chunks": len(texts)}
83
 
84
+ # ---- Ask endpoint ----
 
class AskRequest(BaseModel):
    """Request body accepted by the /ask endpoint."""

    # Question text; it is embedded for retrieval and inserted into the prompt.
    question: str
    # Number of nearest chunks requested from the index search.
    top_k: int = 4
    # Passed to the generation pipeline as its max_length.
    max_answer_tokens: int = 256
89
 
 
@app.post("/ask")
def ask(req: AskRequest):
    """Answer a question using the chunks most similar to it in the index.

    The question is embedded and L2-normalised, the top_k nearest chunks are
    fetched from the FAISS index, and their text is passed as context to the
    generation pipeline. URLs of the source articles are appended to the
    answer when any are available.

    Args:
        req: Question, number of chunks to retrieve and generation length.

    Returns:
        A dict with "answer" (generated text, followed by the source URLs
        when there are any) and "sources" (metadata of the chunks used).

    Raises:
        HTTPException: 404 if no index has been saved yet; 400 if top_k < 1.
    """
    index, metadata = load_index()
    if index is None:
        raise HTTPException(status_code=404, detail="No index found. Call /build_index first.")
    if req.top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be at least 1.")

    q_emb = embed_model.encode([req.question], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)
    _, neighbours = index.search(q_emb, req.top_k)

    # FAISS fills missing results with -1 when the index holds fewer than
    # top_k vectors; as a Python index that would silently pick the last
    # chunk, so those slots are dropped.
    hits = [int(i) for i in neighbours[0] if i >= 0]
    sources = [metadata["metas"][i] for i in hits]

    # Keep each non-empty URL once, in retrieval order: several chunks often
    # come from the same article.
    urls = list(dict.fromkeys(meta["url"] for meta in sources if meta.get("url")))

    retrieved = [metadata["texts"][i] for i in hits]
    context = "\n\n".join(retrieved)
    prompt = f"Context:\n{context}\n\nQuestion: {req.question}\nAnswer:"
    out = gen_pipeline(prompt, max_length=req.max_answer_tokens, do_sample=False)[0]["generated_text"]

    # Only add the pointer when there is something to point at; otherwise the
    # answer would end with a dangling "Find out more at ".
    answer = f"{out} Find out more at {', '.join(urls)}" if urls else out

    return {
        "answer": answer,
        "sources": sources,
    }
111
+
112
+ # ---- Health check ----
@app.get("/health")
def health():
    """Health-check endpoint; always returns the same static payload."""
    payload = {"status": "ok"}
    return payload