janajankovic commited on
Commit
d6ba2e2
·
verified ·
1 Parent(s): 8c4bae4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -51
app.py CHANGED
@@ -13,14 +13,18 @@ from sklearn.metrics.pairwise import cosine_similarity
13
  # CONFIG – EDIT THESE TWO LINES TO MATCH YOUR REPOS
14
  # ------------------------------------------------------------------
15
  BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "cjvt/GaMS-1B-Chat")
16
- # Replace this with the name of YOUR fine-tuned adapter repo
17
  ADAPTER_ID = os.getenv("ADAPTER_ID", "janajankovic/autotrain-juhh6-uwiv9")
18
 
19
  CSV_PATH = "chunks_for_autotrain.csv"
20
- TOP_K = 4 # how many most similar chunks to use as context
21
  MAX_INPUT_LEN = 2048
22
  MAX_NEW_TOKENS = 256
23
 
 
 
 
 
 
24
 
25
  # ------------------------------------------------------------------
26
  # LOAD CSV CHUNKS + TF-IDF INDEX
@@ -30,13 +34,11 @@ if not os.path.exists(CSV_PATH):
30
 
31
  df = pd.read_csv(CSV_PATH)
32
 
33
- # Try to guess which column holds the text
34
  if "chunk" in df.columns:
35
  text_col = "chunk"
36
  elif "text" in df.columns:
37
  text_col = "text"
38
  else:
39
- # fallback: first column
40
  text_col = df.columns[0]
41
 
42
  chunks = df[text_col].astype(str).tolist()
@@ -57,13 +59,16 @@ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
57
  if tokenizer.pad_token is None:
58
  tokenizer.pad_token = tokenizer.eos_token
59
 
 
 
 
 
60
  base_model = AutoModelForCausalLM.from_pretrained(
61
  BASE_MODEL_ID,
62
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
63
  )
64
 
65
  model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
66
- # Merge LoRA into the base model so we can use it like a normal CausalLM
67
  model = model.merge_and_unload()
68
  model.to(device)
69
  model.eval()
@@ -80,7 +85,6 @@ SYSTEM_PROMPT = (
80
 
81
 
82
  def retrieve_chunks(question: str, top_k: int = TOP_K):
83
- """Return top_k most similar chunks for the given question."""
84
  q_vec = vectorizer.transform([question])
85
  sims = cosine_similarity(q_vec, tfidf_matrix)[0]
86
  top_idx = sims.argsort()[::-1][:top_k]
@@ -89,7 +93,6 @@ def retrieve_chunks(question: str, top_k: int = TOP_K):
89
 
90
  def build_prompt(question: str, retrieved):
91
  context = "\n\n---\n\n".join(retrieved)
92
-
93
  prompt = (
94
  f"{SYSTEM_PROMPT}\n\n"
95
  f"Kontekst:\n{context}\n\n"
@@ -106,13 +109,9 @@ def build_prompt(question: str, retrieved):
106
  # GENERATION FUNCTION FOR CHAT
107
  # ------------------------------------------------------------------
108
  def generate_answer(message: str, history):
109
- # 1) retrieve relevant chunks
110
  retrieved = retrieve_chunks(message, top_k=TOP_K)
111
-
112
- # 2) build prompt
113
  prompt = build_prompt(message, retrieved)
114
 
115
- # 3) tokenize
116
  inputs = tokenizer(
117
  prompt,
118
  return_tensors="pt",
@@ -120,45 +119,64 @@ def generate_answer(message: str, history):
120
  max_length=MAX_INPUT_LEN,
121
  ).to(device)
122
 
123
- # 4) generate with stronger anti-repetition settings
124
- with torch.no_grad():
125
- output_ids = model.generate(
126
- **inputs,
127
- max_new_tokens=MAX_NEW_TOKENS,
128
- do_sample=True,
129
- temperature=0.7,
130
- top_p=0.9,
131
- repetition_penalty=1.15,
132
- no_repeat_ngram_size=4,
133
- pad_token_id=tokenizer.eos_token_id,
134
- )
135
-
136
- # 5) strip the prompt part, decode only new tokens
137
- generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
138
- raw_text = tokenizer.decode(
139
- generated_ids,
140
- skip_special_tokens=True,
141
- ).strip()
142
-
143
- # 6) small cleanup: remove very long runs of the same line
144
- # (simple heuristic to kill the insane repetition cases)
145
- lines = [l.strip() for l in raw_text.splitlines() if l.strip()]
146
- cleaned = []
147
- last_line = None
148
- repeat_count = 0
149
- for l in lines:
150
- if l == last_line:
151
- repeat_count += 1
152
- if repeat_count >= 2:
153
- # skip extra repetitions
154
- continue
155
- else:
156
- repeat_count = 0
157
- last_line = l
158
- cleaned.append(l)
159
-
160
- answer = " ".join(cleaned).strip()
161
- return answer or raw_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
 
164
  # ------------------------------------------------------------------
@@ -170,6 +188,5 @@ demo = gr.ChatInterface(
170
  description="Klepetalnik, prilagojen na tvoje gradivo (CSV chunki).",
171
  )
172
 
173
-
174
  if __name__ == "__main__":
175
  demo.launch()
 
13
  # CONFIG – EDIT THESE TWO LINES TO MATCH YOUR REPOS
14
  # ------------------------------------------------------------------
15
  BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "cjvt/GaMS-1B-Chat")
 
16
  ADAPTER_ID = os.getenv("ADAPTER_ID", "janajankovic/autotrain-juhh6-uwiv9")
17
 
18
  CSV_PATH = "chunks_for_autotrain.csv"
19
+ TOP_K = 4
20
  MAX_INPUT_LEN = 2048
21
  MAX_NEW_TOKENS = 256
22
 
23
+ # Enforce non-empty answers
24
+ MIN_NEW_TOKENS = 32 # prevent immediate EOS / 1-4 word outputs
25
+ MIN_CHARS = 60 # require roughly one sentence worth of text
26
+ MAX_RETRIES = 2
27
+
28
 
29
  # ------------------------------------------------------------------
30
  # LOAD CSV CHUNKS + TF-IDF INDEX
 
34
 
35
  df = pd.read_csv(CSV_PATH)
36
 
 
37
  if "chunk" in df.columns:
38
  text_col = "chunk"
39
  elif "text" in df.columns:
40
  text_col = "text"
41
  else:
 
42
  text_col = df.columns[0]
43
 
44
  chunks = df[text_col].astype(str).tolist()
 
59
  if tokenizer.pad_token is None:
60
  tokenizer.pad_token = tokenizer.eos_token
61
 
62
+ # CRITICAL: if prompt is too long, keep the END (question + "Odgovor:")
63
+ tokenizer.truncation_side = "left"
64
+ tokenizer.padding_side = "left"
65
+
66
  base_model = AutoModelForCausalLM.from_pretrained(
67
  BASE_MODEL_ID,
68
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
69
  )
70
 
71
  model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
 
72
  model = model.merge_and_unload()
73
  model.to(device)
74
  model.eval()
 
85
 
86
 
87
  def retrieve_chunks(question: str, top_k: int = TOP_K):
 
88
  q_vec = vectorizer.transform([question])
89
  sims = cosine_similarity(q_vec, tfidf_matrix)[0]
90
  top_idx = sims.argsort()[::-1][:top_k]
 
93
 
94
  def build_prompt(question: str, retrieved):
95
  context = "\n\n---\n\n".join(retrieved)
 
96
  prompt = (
97
  f"{SYSTEM_PROMPT}\n\n"
98
  f"Kontekst:\n{context}\n\n"
 
109
  # GENERATION FUNCTION FOR CHAT
110
  # ------------------------------------------------------------------
111
  def generate_answer(message: str, history):
 
112
  retrieved = retrieve_chunks(message, top_k=TOP_K)
 
 
113
  prompt = build_prompt(message, retrieved)
114
 
 
115
  inputs = tokenizer(
116
  prompt,
117
  return_tensors="pt",
 
119
  max_length=MAX_INPUT_LEN,
120
  ).to(device)
121
 
122
+ def _generate_once(gen_kwargs: dict) -> str:
123
+ with torch.no_grad():
124
+ out = model.generate(**inputs, **gen_kwargs)
125
+ gen_ids = out[0][inputs["input_ids"].shape[1]:]
126
+ return tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
127
+
128
+ base_kwargs = dict(
129
+ max_new_tokens=MAX_NEW_TOKENS,
130
+ do_sample=True,
131
+ temperature=0.7,
132
+ top_p=0.9,
133
+ repetition_penalty=1.15,
134
+ no_repeat_ngram_size=4,
135
+ pad_token_id=tokenizer.eos_token_id,
136
+ eos_token_id=tokenizer.eos_token_id,
137
+ )
138
+
139
+ # Try to enforce minimum generation length (prevents 1–4 word answers).
140
+ try_kwargs = dict(base_kwargs)
141
+ try_kwargs["min_new_tokens"] = MIN_NEW_TOKENS
142
+
143
+ raw_text = ""
144
+ for _ in range(MAX_RETRIES + 1):
145
+ try:
146
+ raw_text = _generate_once(try_kwargs)
147
+ except TypeError:
148
+ # Older transformers: min_new_tokens not supported
149
+ raw_text = _generate_once(base_kwargs)
150
+
151
+ # Cleanup repeated identical lines
152
+ lines = [l.strip() for l in raw_text.splitlines() if l.strip()]
153
+ cleaned = []
154
+ last_line = None
155
+ rep = 0
156
+ for l in lines:
157
+ if l == last_line:
158
+ rep += 1
159
+ if rep >= 2:
160
+ continue
161
+ else:
162
+ rep = 0
163
+ last_line = l
164
+ cleaned.append(l)
165
+
166
+ answer = " ".join(cleaned).strip() or raw_text.strip()
167
+
168
+ # Accept if it looks like at least one sentence
169
+ if len(answer) >= MIN_CHARS and any(p in answer for p in ".!?"):
170
+ return answer
171
+
172
+ # Retry: loosen constraints a bit to avoid early stop / dead outputs
173
+ try_kwargs["temperature"] = min(0.95, try_kwargs.get("temperature", 0.7) + 0.15)
174
+ try_kwargs["top_p"] = min(0.98, try_kwargs.get("top_p", 0.9) + 0.05)
175
+ try_kwargs["repetition_penalty"] = max(1.05, try_kwargs.get("repetition_penalty", 1.15) - 0.05)
176
+ try_kwargs["no_repeat_ngram_size"] = max(2, try_kwargs.get("no_repeat_ngram_size", 4) - 1)
177
+
178
+ # Hard fallback: guarantees at least one full sentence
179
+ return "V podanih odlomkih ni dovolj informacij za zanesljiv odgovor na to vprašanje."
180
 
181
 
182
  # ------------------------------------------------------------------
 
188
  description="Klepetalnik, prilagojen na tvoje gradivo (CSV chunki).",
189
  )
190
 
 
191
  if __name__ == "__main__":
192
  demo.launch()