Bob-Potato committed on
Commit
5e8b2a9
·
verified ·
1 Parent(s): 6e899e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -1,27 +1,32 @@
 
 
 
1
  from fastapi import FastAPI, Request
2
  import json
3
  import faiss
4
  from sentence_transformers import SentenceTransformer
5
- from transformers import AutoTokenizer, AutoModelForCausalLM
6
  import torch
7
 
8
  app = FastAPI()
9
 
10
  # ---------------------------
11
- # 1. Încarcă modelul Gemma 1B
12
  # ---------------------------
13
- MODEL_NAME = "distilgpt2" # modelul mic Gemma 1B
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
16
 
17
  # ---------------------------
18
  # 2. Încarcă articolele și embeddings
19
  # ---------------------------
20
- with open("articles.json", "r") as f:
21
  articles = json.load(f)
22
 
 
23
  sentences = [a["content"] for a in articles]
24
 
 
25
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
26
  embeddings = embedder.encode(sentences)
27
  index = faiss.IndexFlatL2(embeddings.shape[1])
@@ -34,7 +39,7 @@ index.add(embeddings)
34
  async def ask(request: Request):
35
  data = await request.json()
36
  question = data.get("question", "")
37
-
38
  # căutare semantică
39
  q_emb = embedder.encode([question])
40
  D, I = index.search(q_emb, k=3)
@@ -46,7 +51,7 @@ async def ask(request: Request):
46
  inputs = tokenizer(prompt, return_tensors="pt")
47
  outputs = model.generate(**inputs, max_new_tokens=150)
48
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
49
-
50
  return {"answer": answer}
51
 
52
  # ---------------------------
 
1
+ import os
2
+ os.environ["HF_HOME"] = "/tmp/hf" # cache scriibil în Space
3
+
4
  from fastapi import FastAPI, Request
5
  import json
6
  import faiss
7
  from sentence_transformers import SentenceTransformer
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
  import torch
10
 
11
  app = FastAPI()
12
 
13
  # ---------------------------
14
+ # 1. Încarcă modelul
15
  # ---------------------------
16
+ MODEL_NAME = "google/flan-t5-small" # public și mic
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
18
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
19
 
20
  # ---------------------------
21
  # 2. Încarcă articolele și embeddings
22
  # ---------------------------
23
+ with open("articles.json", "r", encoding="utf-8") as f:
24
  articles = json.load(f)
25
 
26
+ # fiecare articol -> text
27
  sentences = [a["content"] for a in articles]
28
 
29
+ # embeddings rapide
30
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
31
  embeddings = embedder.encode(sentences)
32
  index = faiss.IndexFlatL2(embeddings.shape[1])
 
39
  async def ask(request: Request):
40
  data = await request.json()
41
  question = data.get("question", "")
42
+
43
  # căutare semantică
44
  q_emb = embedder.encode([question])
45
  D, I = index.search(q_emb, k=3)
 
51
  inputs = tokenizer(prompt, return_tensors="pt")
52
  outputs = model.generate(**inputs, max_new_tokens=150)
53
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
54
+
55
  return {"answer": answer}
56
 
57
  # ---------------------------