Bob-Potato commited on
Commit
a46ce71
·
verified ·
1 Parent(s): bcd3bd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -32
app.py CHANGED
@@ -1,24 +1,22 @@
1
  #!/usr/bin/env python3
2
  """
3
- HF Space API for Article Q&A AI.
4
- Optimized for CPU / Free Tier.
5
- Uses tiny-flan-t5 for faster generation.
6
  """
7
 
8
  import os
9
  import json
10
- import faiss
11
  import numpy as np
 
 
12
  from sentence_transformers import SentenceTransformer
13
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
14
- from fastapi import FastAPI
15
- from pydantic import BaseModel
16
 
17
  # ---- Config ----
18
  CHUNK_SIZE = 500
19
  CHUNK_OVERLAP = 100
20
  JSON_FILE = "articles.json"
21
- TOP_K = 3 # fewer chunks for speed
22
  SERVER_PORT = 7860
23
 
24
  # ---- Global variables ----
@@ -28,12 +26,10 @@ INDEX_DIM = None
28
 
29
  # ---- Models ----
30
  embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
31
- gen_model_name = "sshleifer/tiny-flan-t5"
32
  tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
33
  gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
34
- gen_pipeline = pipeline(
35
- "text2text-generation", model=gen_model, tokenizer=tokenizer, device=-1
36
- )
37
 
38
  # ---- Helpers ----
39
  def chunk_text(text):
@@ -43,17 +39,19 @@ def chunk_text(text):
43
  end = min(start + CHUNK_SIZE, len(text))
44
  chunks.append(text[start:end])
45
  start = end - CHUNK_OVERLAP
46
- if start < 0:
47
- start = 0
48
- if start >= len(text):
49
- break
50
  return chunks
51
 
52
  def build_index_in_memory():
53
  print("🚀 Building FAISS index...")
 
 
 
54
  if not os.path.exists(JSON_FILE):
55
  print("❌ articles.json missing")
56
  return None, None, None
 
57
  try:
58
  with open(JSON_FILE, "r", encoding="utf-8") as f:
59
  articles = json.load(f)
@@ -68,9 +66,8 @@ def build_index_in_memory():
68
  embeddings_list, texts, metas = [], [], []
69
 
70
  for art_id, art in enumerate(articles):
71
- content = art.get("Continut") or art.get("continut") or ""
72
- url = art.get("URL") or art.get("url") or ""
73
- title = art.get("Titlu") or art.get("titlu") or f"articol_{art_id}"
74
  if not content.strip():
75
  continue
76
  chunks = chunk_text(content)
@@ -81,7 +78,7 @@ def build_index_in_memory():
81
  embs = embs.reshape(1, -1)
82
  embeddings_list.append(embs)
83
  texts.extend(chunks)
84
- metas.extend([{"source": title, "url": url, "chunk_id": i} for i in range(len(chunks))])
85
 
86
  if not embeddings_list:
87
  print("❌ No valid chunks found")
@@ -92,12 +89,14 @@ def build_index_in_memory():
92
  d = embeddings.shape[1]
93
  index = faiss.IndexFlatIP(d)
94
  index.add(embeddings)
 
95
  metadata = {"texts": texts, "metas": metas}
96
  print(f"✅ Index built with {len(texts)} chunks")
97
  return index, metadata, d
98
 
99
- def ask_question(question, top_k=TOP_K, max_answer_tokens=64):
100
  global INDEX, METADATA, INDEX_DIM
 
101
  if not question.strip():
102
  return "⚠️ Please provide a question."
103
 
@@ -142,17 +141,15 @@ def ask_question(question, top_k=TOP_K, max_answer_tokens=64):
142
 
143
  return f"{out} Find out more at {', '.join([u for u in urls if u])}"
144
 
145
- # ---- FastAPI ----
146
- app = FastAPI()
147
-
148
- class Question(BaseModel):
149
- text: str
150
-
151
- @app.post("/ask")
152
- def ask(q: Question):
153
- return {"answer": ask_question(q.text)}
154
 
155
  if __name__ == "__main__":
156
- import uvicorn
157
- INDEX, METADATA, INDEX_DIM = build_index_in_memory()
158
- uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT)
 
1
  #!/usr/bin/env python3
2
  """
3
+ Hugging Face Space app: Article Q&A AI
4
+ Simplified, CPU-friendly, public model (google/flan-t5-small)
 
5
  """
6
 
7
  import os
8
  import json
 
9
  import numpy as np
10
+ import faiss
11
+ import gradio as gr
12
  from sentence_transformers import SentenceTransformer
13
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
 
14
 
15
  # ---- Config ----
16
  CHUNK_SIZE = 500
17
  CHUNK_OVERLAP = 100
18
  JSON_FILE = "articles.json"
19
+ TOP_K = 4
20
  SERVER_PORT = 7860
21
 
22
  # ---- Global variables ----
 
26
 
27
  # ---- Models ----
28
  embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
29
+ gen_model_name = "google/flan-t5-small"
30
  tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
31
  gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
32
+ gen_pipeline = pipeline("text2text-generation", model=gen_model, tokenizer=tokenizer, device=-1)
 
 
33
 
34
  # ---- Helpers ----
35
  def chunk_text(text):
 
39
  end = min(start + CHUNK_SIZE, len(text))
40
  chunks.append(text[start:end])
41
  start = end - CHUNK_OVERLAP
42
+ if start < 0: start = 0
43
+ if start >= len(text): break
 
 
44
  return chunks
45
 
46
  def build_index_in_memory():
47
  print("🚀 Building FAISS index...")
48
+ print("Current WORKDIR:", os.getcwd())
49
+ print("Files:", os.listdir("."))
50
+
51
  if not os.path.exists(JSON_FILE):
52
  print("❌ articles.json missing")
53
  return None, None, None
54
+
55
  try:
56
  with open(JSON_FILE, "r", encoding="utf-8") as f:
57
  articles = json.load(f)
 
66
  embeddings_list, texts, metas = [], [], []
67
 
68
  for art_id, art in enumerate(articles):
69
+ content = art.get("continut") or art.get("Continut") or ""
70
+ url = art.get("url") or art.get("URL") or ""
 
71
  if not content.strip():
72
  continue
73
  chunks = chunk_text(content)
 
78
  embs = embs.reshape(1, -1)
79
  embeddings_list.append(embs)
80
  texts.extend(chunks)
81
+ metas.extend([{"source": art.get("titlu") or art.get("Titlu") or f"articol_{art_id}", "url": url, "chunk_id": i} for i in range(len(chunks))])
82
 
83
  if not embeddings_list:
84
  print("❌ No valid chunks found")
 
89
  d = embeddings.shape[1]
90
  index = faiss.IndexFlatIP(d)
91
  index.add(embeddings)
92
+
93
  metadata = {"texts": texts, "metas": metas}
94
  print(f"✅ Index built with {len(texts)} chunks")
95
  return index, metadata, d
96
 
97
+ def ask_question(question, top_k=TOP_K, max_answer_tokens=256):
98
  global INDEX, METADATA, INDEX_DIM
99
+
100
  if not question.strip():
101
  return "⚠️ Please provide a question."
102
 
 
141
 
142
  return f"{out} Find out more at {', '.join([u for u in urls if u])}"
143
 
144
+ def main():
145
+ print("🚀 Starting Article Q&A AI...")
146
+ iface = gr.Interface(
147
+ fn=ask_question,
148
+ inputs=[gr.Textbox(label="Întrebare")],
149
+ outputs=[gr.Textbox(label="Răspuns")],
150
+ live=False,
151
+ )
152
+ iface.launch(server_name="0.0.0.0", server_port=SERVER_PORT)
153
 
154
  if __name__ == "__main__":
155
+ main()