Bob-Potato committed
Commit f78ffb5 · verified · 1 Parent(s): 48d4c18

Update app.py

Files changed (1):
  app.py +52 -35
app.py CHANGED
@@ -1,40 +1,57 @@
- import gradio as gr
- import os
- import requests
-
- # Free public model for the HF Inference API
- MODEL_ID = "distilbert-base-uncased"
-
- HF_API_TOKEN = os.getenv("HF_API_TOKEN")  # may be empty for public models
- HF_API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
- HEADERS = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
-
- def ask_ai(question):
-     if not question.strip():
-         return "Please type a question."
-
-     payload = {
-         "inputs": question,
-         "parameters": {"max_new_tokens": 150, "return_full_text": False}
-     }
-
-     try:
-         r = requests.post(HF_API_URL, headers=HEADERS, json=payload, timeout=30)
-         r.raise_for_status()
-     except Exception as e:
-         return f"Error calling the HF Inference API: {str(e)}"
-
-     res = r.json()
-     if isinstance(res, list) and "generated_text" in res[0]:
-         return res[0]["generated_text"]
-     return str(res)
-
- iface = gr.Interface(
-     fn=ask_ai,
-     inputs=gr.Textbox(lines=2, placeholder="Ask something..."),
-     outputs="text",
-     title="Chatbot HF API",
-     description="A working chatbot that runs on the HF Inference API without downloading models."
- )
-
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
+ from fastapi import FastAPI, Request
+ import json
+ import faiss
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+
+ app = FastAPI()
+
+ # ---------------------------
+ # 1. Load the Gemma 1B model
+ # ---------------------------
+ MODEL_NAME = "google/gemma-3-1b-it"  # the small Gemma 1B model
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+
+ # ---------------------------
+ # 2. Load the articles and embeddings
+ # ---------------------------
+ with open("articles.json", "r") as f:
+     articles = json.load(f)
+
+ sentences = [a["content"] for a in articles]
+
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
+ embeddings = embedder.encode(sentences)
+ index = faiss.IndexFlatL2(embeddings.shape[1])
+ index.add(embeddings)
+
+ # ---------------------------
+ # 3. Endpoint for questions
+ # ---------------------------
+ @app.post("/ask")
+ async def ask(request: Request):
+     data = await request.json()
+     question = data.get("question", "")
+
+     # semantic search over the article embeddings
+     q_emb = embedder.encode([question])
+     D, I = index.search(q_emb, k=3)
+     context = " ".join([sentences[i] for i in I[0]])
+
+     # prompt for the model
+     prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
+
+     inputs = tokenizer(prompt, return_tensors="pt")
+     outputs = model.generate(**inputs, max_new_tokens=150)
+     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     return {"answer": answer}
+
+ # ---------------------------
+ # 4. Run server
+ # ---------------------------
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
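The rewritten app pulls in several heavyweight dependencies that the old Gradio version did not need. A sketch of a matching requirements.txt, inferred from the imports above (package names only, versions left unpinned; note that the faiss import is published on PyPI as faiss-cpu):

fastapi
uvicorn
torch
transformers
sentence-transformers
faiss-cpu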
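The diff reads articles.json but does not show its schema; the loader only touches the "content" key of each entry. A minimal sketch of a compatible file, written from Python with hypothetical titles and text:

import json

# Hypothetical sample entries; app.py only requires the "content" key.
sample_articles = [
    {"title": "FAISS", "content": "FAISS builds indexes over dense vectors for fast similarity search."},
    {"title": "Gemma", "content": "Gemma is a family of lightweight open-weight language models."},
]

with open("articles.json", "w", encoding="utf-8") as f:
    json.dump(sample_articles, f, ensure_ascii=False, indent=2)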
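Once the server is up, the new /ask endpoint takes a JSON body with a "question" field and returns {"answer": ...}. A minimal sketch of a client call, assuming the host and port configured in the diff (localhost:7860) and a hypothetical question:

import requests

resp = requests.post(
    "http://localhost:7860/ask",
    json={"question": "What is FAISS used for?"},  # hypothetical question
    timeout=120,  # generation on CPU can be slow
)
resp.raise_for_status()
print(resp.json()["answer"])

Note that tokenizer.decode on the full output sequence returns the prompt plus the completion, so as written the "answer" field will echo the retrieved context and the question before the generated text.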