Bob-Potato committed on
Commit
bcd3bd1
·
verified ·
1 Parent(s): 8baadf1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -35
app.py CHANGED
@@ -1,22 +1,24 @@
1
  #!/usr/bin/env python3
2
  """
3
- Hugging Face Space app for Article Q&A AI.
4
- Robust version supporting JSON with different key capitalizations.
 
5
  """
6
 
7
  import os
8
  import json
9
  import faiss
10
  import numpy as np
11
- import gradio as gr
12
  from sentence_transformers import SentenceTransformer
13
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
 
14
 
15
  # ---- Config ----
16
  CHUNK_SIZE = 500
17
  CHUNK_OVERLAP = 100
18
- JSON_FILE = "articles.json" # relative to WORKDIR
19
- TOP_K = 4
20
  SERVER_PORT = 7860
21
 
22
  # ---- Global variables ----
@@ -26,7 +28,7 @@ INDEX_DIM = None
26
 
27
  # ---- Models ----
28
  embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
29
- gen_model_name = "google/flan-t5-small"
30
  tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
31
  gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
32
  gen_pipeline = pipeline(
@@ -48,17 +50,10 @@ def chunk_text(text):
48
  return chunks
49
 
50
  def build_index_in_memory():
51
- """Build FAISS index in memory and return index, metadata, dim"""
52
  print("🚀 Building FAISS index...")
53
- print("Current WORKDIR:", os.getcwd())
54
- print("Files in WORKDIR:", os.listdir("."))
55
- print("Looking for articles.json:", JSON_FILE)
56
- print("Exists?", os.path.exists(JSON_FILE))
57
-
58
  if not os.path.exists(JSON_FILE):
59
  print("❌ articles.json missing")
60
  return None, None, None
61
-
62
  try:
63
  with open(JSON_FILE, "r", encoding="utf-8") as f:
64
  articles = json.load(f)
@@ -67,24 +62,20 @@ def build_index_in_memory():
67
  return None, None, None
68
 
69
  if not articles:
70
- print("❌ articles.json is empty")
71
  return None, None, None
72
 
73
  embeddings_list, texts, metas = [], [], []
74
 
75
  for art_id, art in enumerate(articles):
76
- # Support both lowercase and capitalized keys
77
  content = art.get("Continut") or art.get("continut") or ""
78
  url = art.get("URL") or art.get("url") or ""
79
  title = art.get("Titlu") or art.get("titlu") or f"articol_{art_id}"
80
-
81
  if not content.strip():
82
  continue
83
-
84
  chunks = chunk_text(content)
85
- if len(chunks) == 0:
86
  continue
87
-
88
  embs = embed_model.encode(chunks, convert_to_numpy=True)
89
  if embs.ndim == 1:
90
  embs = embs.reshape(1, -1)
@@ -92,7 +83,7 @@ def build_index_in_memory():
92
  texts.extend(chunks)
93
  metas.extend([{"source": title, "url": url, "chunk_id": i} for i in range(len(chunks))])
94
 
95
- if len(embeddings_list) == 0:
96
  print("❌ No valid chunks found")
97
  return None, None, None
98
 
@@ -101,14 +92,12 @@ def build_index_in_memory():
101
  d = embeddings.shape[1]
102
  index = faiss.IndexFlatIP(d)
103
  index.add(embeddings)
104
-
105
  metadata = {"texts": texts, "metas": metas}
106
  print(f"✅ Index built with {len(texts)} chunks")
107
  return index, metadata, d
108
 
109
- def ask_question(question, top_k=TOP_K, max_answer_tokens=256):
110
  global INDEX, METADATA, INDEX_DIM
111
-
112
  if not question.strip():
113
  return "⚠️ Please provide a question."
114
 
@@ -121,7 +110,6 @@ def ask_question(question, top_k=TOP_K, max_answer_tokens=256):
121
  if q_emb.ndim == 1:
122
  q_emb = q_emb.reshape(1, -1)
123
 
124
- # Rebuild index if embedding dimension mismatch
125
  if INDEX_DIM is None or q_emb.shape[1] != INDEX_DIM:
126
  INDEX, METADATA, INDEX_DIM = build_index_in_memory()
127
  if INDEX is None or q_emb.shape[1] != INDEX_DIM:
@@ -154,18 +142,17 @@ def ask_question(question, top_k=TOP_K, max_answer_tokens=256):
154
 
155
  return f"{out} Find out more at {', '.join([u for u in urls if u])}"
156
 
157
- def main():
158
- print("🚀 Starting Article Q&A AI...")
159
- print(f"📁 Looking for articles.json at {JSON_FILE}")
160
 
161
- iface = gr.Interface(
162
- fn=ask_question,
163
- inputs=[gr.Textbox(label="Întrebare")],
164
- outputs=[gr.Textbox(label="Răspuns")],
165
- live=False,
166
- )
167
 
168
- iface.launch(server_name="0.0.0.0", server_port=SERVER_PORT)
 
 
169
 
170
  if __name__ == "__main__":
171
- main()
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ HF Space API for Article Q&A AI.
4
+ Optimized for CPU / Free Tier.
5
+ Uses tiny-flan-t5 for faster generation.
6
  """
7
 
8
  import os
9
  import json
10
  import faiss
11
  import numpy as np
 
12
  from sentence_transformers import SentenceTransformer
13
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
14
+ from fastapi import FastAPI
15
+ from pydantic import BaseModel
16
 
17
  # ---- Config ----
18
  CHUNK_SIZE = 500
19
  CHUNK_OVERLAP = 100
20
+ JSON_FILE = "articles.json"
21
+ TOP_K = 3 # fewer chunks for speed
22
  SERVER_PORT = 7860
23
 
24
  # ---- Global variables ----
 
28
 
29
  # ---- Models ----
30
  embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
31
+ gen_model_name = "sshleifer/tiny-flan-t5"
32
  tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
33
  gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
34
  gen_pipeline = pipeline(
 
50
  return chunks
51
 
52
  def build_index_in_memory():
 
53
  print("🚀 Building FAISS index...")
 
 
 
 
 
54
  if not os.path.exists(JSON_FILE):
55
  print("❌ articles.json missing")
56
  return None, None, None
 
57
  try:
58
  with open(JSON_FILE, "r", encoding="utf-8") as f:
59
  articles = json.load(f)
 
62
  return None, None, None
63
 
64
  if not articles:
65
+ print("❌ articles.json empty")
66
  return None, None, None
67
 
68
  embeddings_list, texts, metas = [], [], []
69
 
70
  for art_id, art in enumerate(articles):
 
71
  content = art.get("Continut") or art.get("continut") or ""
72
  url = art.get("URL") or art.get("url") or ""
73
  title = art.get("Titlu") or art.get("titlu") or f"articol_{art_id}"
 
74
  if not content.strip():
75
  continue
 
76
  chunks = chunk_text(content)
77
+ if not chunks:
78
  continue
 
79
  embs = embed_model.encode(chunks, convert_to_numpy=True)
80
  if embs.ndim == 1:
81
  embs = embs.reshape(1, -1)
 
83
  texts.extend(chunks)
84
  metas.extend([{"source": title, "url": url, "chunk_id": i} for i in range(len(chunks))])
85
 
86
+ if not embeddings_list:
87
  print("❌ No valid chunks found")
88
  return None, None, None
89
 
 
92
  d = embeddings.shape[1]
93
  index = faiss.IndexFlatIP(d)
94
  index.add(embeddings)
 
95
  metadata = {"texts": texts, "metas": metas}
96
  print(f"✅ Index built with {len(texts)} chunks")
97
  return index, metadata, d
98
 
99
+ def ask_question(question, top_k=TOP_K, max_answer_tokens=64):
100
  global INDEX, METADATA, INDEX_DIM
 
101
  if not question.strip():
102
  return "⚠️ Please provide a question."
103
 
 
110
  if q_emb.ndim == 1:
111
  q_emb = q_emb.reshape(1, -1)
112
 
 
113
  if INDEX_DIM is None or q_emb.shape[1] != INDEX_DIM:
114
  INDEX, METADATA, INDEX_DIM = build_index_in_memory()
115
  if INDEX is None or q_emb.shape[1] != INDEX_DIM:
 
142
 
143
  return f"{out} Find out more at {', '.join([u for u in urls if u])}"
144
 
145
+ # ---- FastAPI ----
146
+ app = FastAPI()
 
147
 
148
+ class Question(BaseModel):
149
+ text: str
 
 
 
 
150
 
151
+ @app.post("/ask")
152
+ def ask(q: Question):
153
+ return {"answer": ask_question(q.text)}
154
 
155
  if __name__ == "__main__":
156
+ import uvicorn
157
+ INDEX, METADATA, INDEX_DIM = build_index_in_memory()
158
+ uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT)