hamxaameer committed
Commit 0e156ba · verified · 1 Parent(s): 486261d

Update app.py

Files changed (1)
  1. app.py +121 -54
app.py CHANGED
@@ -273,85 +273,152 @@ def generate_llm_answer(
273
274       context_text = "\n\n".join(context_parts)
275
276  -    # Request a long, natural-flowing answer (target 400-700 words).
277  -    # We keep context trimmed to avoid exceeding input limits, but ask the model
278  -    # to produce an extended answer. Use sampling for natural flow on CPU.
279       if attempt == 1:
280           temperature = 0.70
281  -        max_new_tokens = 1000
282           top_p = 0.92
283           repetition_penalty = 1.1
284       else:
285           temperature = 0.82
286  -        max_new_tokens = 1200
287           top_p = 0.95
288           repetition_penalty = 1.15
289
290  -    user_prompt = f"""Please write a clear, natural-flowing, well-structured fashion answer using the context below.
291
292  - Question: {query}
293
294  - Context (use where helpful):
295  - {context_text[:1400]}
296
297  - Requirements:
298  - - Write between 400 and 700 words (aim for natural paragraphs, not a list of short sentences).
299  - - Use the provided context where relevant and add practical, actionable advice.
300  - - Keep a friendly, expert tone and avoid hedging phrases like "I can't" or "I don't know".
301
302  - Answer:
303    """
304
305  -    try:
306  -        logger.info(f" → Calling {CONFIG.get('llm_model', 'LLM')} (temp={temperature}, max_new_tokens={max_new_tokens})...")
307  -
308  -        # Use max_new_tokens for the generated output length (safer than max_length here).
309  -        output = llm_client(
310  -            user_prompt,
311  -            max_new_tokens=max_new_tokens,
312  -            temperature=temperature,
313  -            top_p=top_p,
314  -            do_sample=True,
315  -            num_beams=1,
316  -            repetition_penalty=repetition_penalty,
317  -            early_stopping=False
318  -        )
319
320  -        # transformers text2text pipeline returns a list of dicts; key is often 'generated_text'
321  -        response = output[0].get('generated_text', '') if isinstance(output, list) and output else str(output).strip()
322  -        response = response.strip()
323
324  -        if not response:
325  -            logger.warning(f" ✗ Empty response (attempt {attempt})")
326  -            return None
327
328  -        # Word-count validation: ensure long-form output (target 400-700 words)
329  -        words = response.split()
330  -        word_count = len(words)
331
332  -        if word_count < 380:
333  -            logger.warning(f" ✗ Response too short ({word_count} words) — expected ~400-700")
334  -            return None
335
336  -        # If excessively long, truncate to 700 words for safety and readability
337  -        if word_count > 750:
338  -            logger.info(f" ⚠️ Response very long ({word_count} words). Truncating to 700 words.")
339  -            response = ' '.join(words[:700]) + '...'
340  -            word_count = 700
341
342  -        # Filter out apology/hedging starts
343  -        apology_phrases = ["i cannot", "i can't", "i'm sorry", "i apologize", "i don't have"]
344  -        if any(phrase in response.lower()[:200] for phrase in apology_phrases):
345  -            logger.warning(f" ✗ Apology/hedging detected in response start")
346  -            return None
347
348  -        logger.info(f" ✅ Generated long-form answer ({word_count} words, {len(response)} chars)")
349  -        return response
350
351  -    except Exception as e:
352  -        logger.error(f" ✗ Generation error: {e}")
353           return None
354
355   def generate_answer_langchain(
356       query: str,
357       vectorstore,

273
274       context_text = "\n\n".join(context_parts)
275
276  +    # Primary strategy: try a single long-form generation first. If that fails
277  +    # (too short or truncated), fall back to an iterative multi-pass generator
278  +    # that appends continuation chunks until we reach the target word count.
279  +
280  +    target_min_words = 400
281  +    target_max_words = 700
282  +    chunk_target_words = 200
283  +    max_iterations = 4
284  +
285  +    def call_model(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
286  +        logger.info(f" → Model call (temp={temperature}, max_new_tokens={max_new_tokens})")
287  +        try:
288  +            out = llm_client(
289  +                prompt,
290  +                max_new_tokens=max_new_tokens,
291  +                temperature=temperature,
292  +                top_p=top_p,
293  +                do_sample=True,
294  +                num_beams=1,
295  +                repetition_penalty=repetition_penalty,
296  +                early_stopping=False
297  +            )
298  +            if isinstance(out, list) and out:
299  +                return out[0].get('generated_text', '') if isinstance(out[0], dict) else str(out[0])
300  +            return str(out)
301  +        except Exception as e:
302  +            logger.error(f" ✗ Model call error: {e}")
303  +            return ''
304  +
305  +    # Build initial prompt
306  +    base_prompt = f"""Please write a clear, natural-flowing, well-structured fashion answer using the context below.
307  +
308  + Question: {query}
309  +
310  + Context (use where helpful):
311  + {context_text[:1200]}
312  +
313  + Requirements:
314  + - Aim for a long-form answer ~{target_min_words}-{target_max_words} words, structured in paragraphs.
315  + - Use the provided context where relevant and add practical, actionable advice.
316  + - Keep a friendly, expert tone and avoid hedging phrases like "I can't" or "I don't know".
317  +
318  + Answer:
319  + """
320  +
321  +    # First attempt: single call with a medium-large token budget (may be limited by model)
322       if attempt == 1:
323           temperature = 0.70
324  +        max_new_tokens = 600
325           top_p = 0.92
326           repetition_penalty = 1.1
327       else:
328           temperature = 0.82
329  +        max_new_tokens = 800
330           top_p = 0.95
331           repetition_penalty = 1.15
332
333  +    initial_output = call_model(base_prompt, max_new_tokens, temperature, top_p, repetition_penalty)
334  +    response = (initial_output or '').strip()
335
336  +    # Basic sanity checks
337  +    if not response:
338  +        logger.warning(" ✗ Empty initial response")
339  +        response = ''
340
341  +    words = response.split()
342  +    word_count = len(words)
343
344  +    # If single-shot succeeded, validate length and return
345  +    if word_count >= target_min_words:
346  +        if word_count > target_max_words:
347  +            response = ' '.join(words[:target_max_words]) + '...'
348  +            word_count = target_max_words
349  +        logger.info(f" ✅ Single-shot generated {word_count} words")
350  +        return response
351
352  +    # Otherwise, try iterative continuation to build up to the target
353  +    accumulated = response
354  +    prev_word_count = word_count
355  +
356  +    for i in range(max_iterations):
357  +        remaining = max(0, target_min_words - len(accumulated.split()))
358  +        if remaining <= 0:
359  +            break
360  +
361  +        # Ask the model to continue without repeating previous content
362  +        continue_prompt = f"""Continue the previous answer in the same tone and style to add about {min(chunk_target_words, remaining)} words.
363  +
364  + Do not repeat sentences already present. Keep paragraphs natural and connected.
365  +
366  + Previous answer:
367  + {accumulated}
368  +
369  + Continue:
370    """
371
372  +        # Slightly higher temperature on continuations to encourage richer text
373  +        cont_output = call_model(continue_prompt, max_new_tokens=450, temperature=0.78, top_p=0.93, repetition_penalty=1.08)
374  +        cont_text = (cont_output or '').strip()
375
376  +        if not cont_text:
377  +            logger.warning(f" ✗ Continuation {i+1} returned empty — stopping")
378  +            break
379
380  +        # Avoid trivial repeats: if continuation repeats the accumulated text, stop
381  +        if cont_text in accumulated or accumulated.endswith(cont_text[:50]):
382  +            logger.warning(f" ✗ Continuation {i+1} appears repetitive — stopping")
383  +            break
384
385  +        # Append and normalize spacing
386  +        accumulated = accumulated.rstrip() + '\n\n' + cont_text
387
388  +        current_word_count = len(accumulated.split())
389  +        logger.info(f" → After continuation {i+1}, words={current_word_count}")
390
391  +        # Stop early if we've reached or exceeded the minimum target
392  +        if current_word_count >= target_min_words:
393  +            break
394
395  +        # Safety: if no progress, break
396  +        if current_word_count == prev_word_count:
397  +            logger.warning(" ✗ No progress from continuation — stopping")
398  +            break
399  +        prev_word_count = current_word_count
400
401  +    final_words = accumulated.split()
402  +    final_count = len(final_words)
403
404  +    if final_count < target_min_words:
405  +        logger.warning(f" ✗ Final answer too short ({final_count} words) after continuations")
406           return None
407
408  +    if final_count > target_max_words:
409  +        logger.info(f" ⚠️ Final answer long ({final_count} words). Truncating to {target_max_words} words.")
410  +        accumulated = ' '.join(final_words[:target_max_words]) + '...'
411  +        final_count = target_max_words
412  +
413  +    # Final check for apology/hedging at start
414  +    apology_phrases = ["i cannot", "i can't", "i'm sorry", "i apologize", "i don't have"]
415  +    if any(phrase in accumulated.lower()[:200] for phrase in apology_phrases):
416  +        logger.warning(" ✗ Apology/hedging detected in final answer")
417  +        return None
418  +
419  +    logger.info(f" ✅ Built long-form answer ({final_count} words)")
420  +    return accumulated
421  +
422   def generate_answer_langchain(
423       query: str,
424       vectorstore,
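
For reference, the core pattern this commit introduces (one single-shot generation, then continuation passes until a minimum word count is reached) can be exercised on its own. The sketch below is a minimal, hypothetical standalone version: it assumes llm_client is a Hugging Face text2text-generation pipeline, and the model name, word targets, and the generate_long_answer helper are illustrative rather than part of this commit.

    from transformers import pipeline

    # Hypothetical stand-in for the app's llm_client (model choice is an assumption).
    llm_client = pipeline("text2text-generation", model="google/flan-t5-base")

    def generate_long_answer(question: str, context: str,
                             min_words: int = 400, max_iterations: int = 4) -> str:
        # Single-shot attempt first, mirroring the commit's "primary strategy".
        prompt = (
            "Answer the question using the context.\n\n"
            f"Question: {question}\n\nContext: {context}\n\nAnswer:"
        )
        answer = llm_client(prompt, max_new_tokens=600, do_sample=True,
                            temperature=0.7, top_p=0.92)[0]["generated_text"].strip()

        # Fall back to continuation passes until the minimum word count is reached
        # or the model stops producing new text.
        for _ in range(max_iterations):
            if len(answer.split()) >= min_words:
                break
            cont_prompt = (
                "Continue the answer below in the same tone without repeating it.\n\n"
                f"{answer}\n\nContinue:"
            )
            cont = llm_client(cont_prompt, max_new_tokens=450, do_sample=True,
                              temperature=0.78, top_p=0.93)[0]["generated_text"].strip()
            if not cont or cont in answer:
                break
            answer = answer.rstrip() + "\n\n" + cont
        return answer

    print(generate_long_answer("How should I style a denim jacket?",
                               "Denim pairs well with neutral layers."))

The committed version layers repetition checks, apology filtering, truncation to the word-count ceiling, and logging on top of this core loop.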