Bob-Potato commited on
Commit
dc7acf2
·
verified ·
1 Parent(s): 615f35e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -109
app.py CHANGED
@@ -1,109 +1,34 @@
1
- # app.py
2
- import os
3
- import json
4
- import faiss
5
- import numpy as np
6
- import requests
7
- from fastapi import FastAPI, HTTPException
8
- from pydantic import BaseModel
9
- from sentence_transformers import SentenceTransformer
10
-
11
- # config
12
- JSON_FILE = "articles.json"
13
- EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
14
- TOP_K = 3
15
- HF_MODEL = os.getenv("HF_MODEL", "google/flan-t5-small") # model hosted on HF
16
- HF_TOKEN = os.getenv("HF_API_TOKEN") # set in Secrets on HF Spaces
17
- SERVER_PORT = int(os.getenv("PORT", 7860))
18
-
19
- # load embeddings
20
- embed_model = SentenceTransformer(EMBED_MODEL)
21
-
22
- # build index
23
- if not os.path.exists(JSON_FILE):
24
- raise FileNotFoundError("articles.json not found")
25
- with open(JSON_FILE, "r", encoding="utf-8") as f:
26
- articles = json.load(f)
27
-
28
- def chunk_text(text, size=500, overlap=100):
29
- chunks=[]
30
- s=0
31
- while s < len(text):
32
- e=min(s+size, len(text))
33
- chunks.append(text[s:e])
34
- s=e-overlap
35
- if s<0: s=0
36
- if s>=len(text): break
37
- return chunks
38
-
39
- texts=[]; metas=[]; embs_list=[]
40
- for i,art in enumerate(articles):
41
- content = art.get("continut") or art.get("Continut") or ""
42
- if not content.strip(): continue
43
- url = art.get("url") or art.get("URL") or ""
44
- title = art.get("titlu") or art.get("Titlu") or f"art_{i}"
45
- chunks = chunk_text(content)
46
- if not chunks: continue
47
- embs = embed_model.encode(chunks, convert_to_numpy=True)
48
- if embs.ndim==1: embs = embs.reshape(1,-1)
49
- embs_list.append(embs)
50
- texts.extend(chunks)
51
- metas.extend([{"title":title,"url":url,"chunk":j} for j in range(len(chunks))])
52
-
53
- if len(embs_list)==0:
54
- raise ValueError("No valid chunks in articles.json")
55
-
56
- embeddings = np.vstack(embs_list).astype("float32")
57
- faiss.normalize_L2(embeddings)
58
- index = faiss.IndexFlatIP(embeddings.shape[1])
59
- index.add(embeddings)
60
- metadata={"texts":texts,"metas":metas}
61
-
62
- # HF generation helper
63
- def generate_via_hf(prompt, max_tokens=128):
64
- if not HF_TOKEN:
65
- raise RuntimeError("HF_API_TOKEN not set in env")
66
- url = f"https://api-inference.huggingface.co/models/{HF_MODEL}"
67
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
68
- payload = {"inputs": prompt, "parameters": {"max_new_tokens": max_tokens, "do_sample": False}}
69
- r = requests.post(url, headers=headers, json=payload, timeout=60)
70
- r.raise_for_status()
71
- data = r.json()
72
- # handle expected response
73
- if isinstance(data, list) and "generated_text" in data[0]:
74
- return data[0]["generated_text"]
75
- if isinstance(data, dict) and "error" in data:
76
- raise RuntimeError("HF error: " + data["error"])
77
- return str(data)
78
-
79
- # FastAPI
80
- app = FastAPI()
81
-
82
- class Q(BaseModel):
83
- question: str
84
-
85
- @app.get("/ping")
86
- def ping():
87
- return {"status":"ok"}
88
-
89
- @app.post("/ask")
90
- def ask(q: Q):
91
- qtext = q.question.strip()
92
- if not qtext:
93
- raise HTTPException(status_code=400, detail="Empty question")
94
- q_emb = embed_model.encode([qtext], convert_to_numpy=True).astype("float32")
95
- if q_emb.ndim==1: q_emb = q_emb.reshape(1,-1)
96
- faiss.normalize_L2(q_emb)
97
- k = min(TOP_K, index.ntotal)
98
- if k<=0:
99
- return {"answer":"No articles indexed."}
100
- D,I = index.search(q_emb, k)
101
- retrieved = [metadata["texts"][i] for i in I[0]]
102
- urls = [metadata["metas"][i].get("url","") for i in I[0]]
103
- context = "\n\n".join(retrieved)
104
- prompt = f"Context:\n{context}\n\nQuestion: {qtext}\nAnswer:"
105
- try:
106
- generated = generate_via_hf(prompt, max_tokens=128)
107
- except Exception as e:
108
- return {"answer": f"HF generation error: {e}", "sources": urls}
109
- return {"answer": generated, "sources": [u for u in urls if u]}
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ import torch
4
+
5
+ # Model public, mic și gratuit
6
+ MODEL_NAME = "google/flan-t5-small"
7
+
8
+ # Încarcă model și tokenizer
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
10
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
11
+
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ model = model.to(device)
14
+
15
+ # Funcția chatbot
16
+ def chat_fn(question):
17
+ if not question.strip():
18
+ return "Te rog scrie o întrebare."
19
+
20
+ inputs = tokenizer(question, return_tensors="pt").to(device)
21
+ outputs = model.generate(**inputs, max_new_tokens=150)
22
+ answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
23
+ return answer
24
+
25
+ # Gradio UI
26
+ iface = gr.Interface(
27
+ fn=chat_fn,
28
+ inputs=gr.Textbox(lines=2, placeholder="Întreabă ceva..."),
29
+ outputs="text",
30
+ title="Chatbot simplu",
31
+ description="Chatbot minimal bazat pe Flan-T5-small (fără date pre-trained locale)."
32
+ )
33
+
34
+ iface.launch(server_name="0.0.0.0", server_port=7860)