iamsahinemir commited on
Commit
896d762
·
verified ·
1 Parent(s): b9880e7

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +15 -4
inference.py CHANGED
@@ -1,6 +1,7 @@
1
  # inference.py
2
 
3
  import re
 
4
  import pandas as pd
5
  import torch
6
  import faiss
@@ -60,8 +61,14 @@ faiss.normalize_L2(row_embs)
60
  row_idx = faiss.IndexFlatIP(row_embs.shape[1])
61
  row_idx.add(row_embs)
62
 
 
 
 
 
 
63
  # ─────────────────────────────────────────────────────────────────────────────
64
  # 5️⃣ generate_answer: app.py’in çağıracağı fonksiyon
 
65
  def generate_answer(user_question: str) -> str:
66
  # (1) normalize “makine” → “RTF makinesi”
67
  q_norm = re.sub(r"\bmakine\b", "RTF makinesi", user_question, flags=re.IGNORECASE)
@@ -70,7 +77,8 @@ def generate_answer(user_question: str) -> str:
70
  if not re.search(r"\b(makine|titreşim|alarm|rtf)\b", q_norm, flags=re.IGNORECASE):
71
  prompt = SYSTEM_PREFIX + "\n" + f"Soru: {q_norm}\nCevap:"
72
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
73
- out = model.generate(**inputs, max_new_tokens=1000)
 
74
  return tokenizer.decode(out[0], skip_special_tokens=True)
75
 
76
  # (3a) veri‐ilgili ise önce static QA
@@ -84,11 +92,13 @@ def generate_answer(user_question: str) -> str:
84
  date=date
85
  )
86
 
87
- # (3b) fallback
88
  if any(tok in ans for tok in ["Cevap bulunamadı", "Lütfen sorunuzda", "Tam olarak anlayamadım"]):
 
89
  ue = embedder_q.encode([q_norm], convert_to_numpy=True)
90
  faiss.normalize_L2(ue)
91
- D_rows, I_rows = row_idx.search(ue, 5)
 
92
  context = "\n".join(row_texts[i] for i in I_rows[0])
93
 
94
  prompt = (
@@ -99,7 +109,8 @@ def generate_answer(user_question: str) -> str:
99
  "Bu verilere dayanarak cevap verin:"
100
  )
101
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
102
- out = model.generate(**inputs, max_new_tokens=1000)
 
103
  return tokenizer.decode(out[0], skip_special_tokens=True)
104
 
105
  # (3c) static QA cevabı
 
1
  # inference.py
2
 
3
  import re
4
+ import threading
5
  import pandas as pd
6
  import torch
7
  import faiss
 
61
  row_idx = faiss.IndexFlatIP(row_embs.shape[1])
62
  row_idx.add(row_embs)
63
 
64
+ # ─────────────────────────────────────────────────────────────────────────────
65
+ # ⚙️ Thread-safety için kilitler
66
+ faiss_lock = threading.Lock()
67
+ model_lock = threading.Lock()
68
+
69
  # ─────────────────────────────────────────────────────────────────────────────
70
  # 5️⃣ generate_answer: app.py’in çağıracağı fonksiyon
71
+ @torch.inference_mode()
72
  def generate_answer(user_question: str) -> str:
73
  # (1) normalize “makine” → “RTF makinesi”
74
  q_norm = re.sub(r"\bmakine\b", "RTF makinesi", user_question, flags=re.IGNORECASE)
 
77
  if not re.search(r"\b(makine|titreşim|alarm|rtf)\b", q_norm, flags=re.IGNORECASE):
78
  prompt = SYSTEM_PREFIX + "\n" + f"Soru: {q_norm}\nCevap:"
79
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
80
+ with model_lock:
81
+ out = model.generate(**inputs, max_new_tokens=1000)
82
  return tokenizer.decode(out[0], skip_special_tokens=True)
83
 
84
  # (3a) veri‐ilgili ise önce static QA
 
92
  date=date
93
  )
94
 
95
+ # (3b) fallback: static QA başarısızsa dynamic RAG + LLM
96
  if any(tok in ans for tok in ["Cevap bulunamadı", "Lütfen sorunuzda", "Tam olarak anlayamadım"]):
97
+ # FAISS üzerinden ilgili satırları al
98
  ue = embedder_q.encode([q_norm], convert_to_numpy=True)
99
  faiss.normalize_L2(ue)
100
+ with faiss_lock:
101
+ D_rows, I_rows = row_idx.search(ue, 5)
102
  context = "\n".join(row_texts[i] for i in I_rows[0])
103
 
104
  prompt = (
 
109
  "Bu verilere dayanarak cevap verin:"
110
  )
111
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
112
+ with model_lock:
113
+ out = model.generate(**inputs, max_new_tokens=1000)
114
  return tokenizer.decode(out[0], skip_special_tokens=True)
115
 
116
  # (3c) static QA cevabı