hamxaameer committed Β· verified
Commit 9ae102e Β· Parent(s): ed0b266

Update app.py

Files changed (1): app.py (+87 -27)
app.py CHANGED
@@ -44,9 +44,8 @@ def initialize_llm():
     logger.info("πŸ”„ Initializing FREE local language model...")
 
     BACKUP_MODELS = [
-        "microsoft/Phi-3-mini-4k-instruct",  # Primary - 3.8B, very efficient
-        "google/flan-t5-large",              # Backup - 780M, good quality
-        "google/flan-t5-base",               # Fallback - 250M, fast
+        "google/flan-t5-base",   # Primary - 250M, very fast on CPU
+        "google/flan-t5-large",  # Backup - 780M, slower but better
     ]
 
     for model_name in BACKUP_MODELS:
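
Note: the commit flips the fallback order so the smallest checkpoint loads first on CPU Spaces, and drops Phi-3 entirely. The surrounding loop presumably advances to the next entry when a load fails; a minimal self-contained sketch of that try-in-order pattern (only the model names come from the diff; the except/continue structure is assumed):

    from transformers import pipeline

    # Hypothetical helper, not from the commit: try each checkpoint in
    # order and fall through on download or memory errors.
    def load_first_available(models):
        for name in models:
            try:
                return pipeline("text2text-generation", model=name)
            except Exception as exc:
                print(f"{name} failed ({exc}); trying next model")
        raise RuntimeError("no model could be loaded")

    llm = load_first_available(["google/flan-t5-base", "google/flan-t5-large"])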
@@ -54,15 +53,20 @@ def initialize_llm():
             logger.info(f"   Trying {model_name}...")
             device = 0 if torch.cuda.is_available() else -1
 
+            # Use text2text-generation for T5 models (not text-generation)
+            task = "text2text-generation" if "t5" in model_name.lower() else "text-generation"
+
             llm_client = pipeline(
-                "text-generation",
+                task,
                 model=model_name,
                 device=device,
-                max_length=512,
+                max_length=300,
                 truncation=True,
+                model_kwargs={"low_cpu_mem_usage": True, "use_cache": True}  # Optimize for speed
             )
 
             CONFIG["llm_model"] = model_name
+            CONFIG["model_type"] = "t5" if "t5" in model_name.lower() else "instruct"
             logger.info(f"βœ… FREE LLM initialized: {model_name}")
             logger.info(f"   Device: {'GPU' if device == 0 else 'CPU'}")
             return llm_client
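
Note: the task switch is the substantive fix in this hunk. The flan-t5 checkpoints are encoder-decoder models, so they must be loaded under text2text-generation (which wraps AutoModelForSeq2SeqLM); the text-generation task is for causal LMs. A quick sanity check of the distinction:

    from transformers import pipeline

    # Seq2seq checkpoint needs the text2text-generation task
    t5 = pipeline("text2text-generation", model="google/flan-t5-base")
    print(t5("Translate English to German: How old are you?",
             max_length=40)[0]["generated_text"])

On the added model_kwargs: low_cpu_mem_usage is a from_pretrained loading option, and use_cache appears to be forwarded to the model config, where it is already on by default for T5, so the second flag is likely a no-op.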
@@ -352,8 +356,13 @@ def generate_llm_answer(
     top_p = 0.97
     repetition_penalty = 1.25
 
-    # Create prompt
-    user_prompt = f"""[INST] Question: {query}
+    # Create prompt based on model type
+    if CONFIG.get("model_type") == "t5":
+        # T5 needs simple input-output format
+        user_prompt = f"Question: {query}\n\nContext: {context_text[:800]}\n\nProvide a helpful fashion answer:"
+    else:
+        # Instruct models use INST format
+        user_prompt = f"""[INST] Question: {query}
 
 Fashion Knowledge:
 {context_text}
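
Note: the T5 branch clips the context with a character slice (context_text[:800]), a rough proxy for the model's token budget. A token-aware alternative (a hypothetical refinement, not part of this commit) would truncate with the tokenizer instead, so the budget matches what the model actually sees:

    from transformers import AutoTokenizer

    context_text = "..."  # retrieved passages, as in the diff

    tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
    ids = tok(context_text, truncation=True, max_length=256)["input_ids"]
    context_trimmed = tok.decode(ids, skip_special_tokens=True)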
@@ -363,17 +372,30 @@ Answer the question using the knowledge above. Be specific and helpful (100-250
     try:
         logger.info(f" β†’ Calling {CONFIG['llm_model']} (temp={temperature}, tokens={max_tokens})...")
 
-        # Call pipeline
-        output = llm_client(
-            user_prompt,
-            max_new_tokens=max_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            do_sample=True,
-            return_full_text=False,
-            pad_token_id=llm_client.tokenizer.eos_token_id
-        )
+        # Call pipeline with model-specific parameters
+        if CONFIG.get("model_type") == "t5":
+            # T5 uses max_length not max_new_tokens
+            output = llm_client(
+                user_prompt,
+                max_length=150,    # Even shorter for faster response
+                temperature=0.7,   # Lower temp for consistency
+                top_p=0.9,
+                do_sample=True,
+                num_beams=1,       # Disable beam search for speed
+                early_stopping=True
+            )
+        else:
+            # Other models use max_new_tokens
+            output = llm_client(
+                user_prompt,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                do_sample=True,
+                return_full_text=False,
+                pad_token_id=llm_client.tokenizer.eos_token_id
+            )
 
         # Extract generated text
         response = output[0]['generated_text'].strip()
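
Note: the branch exists because the two pipeline families take different length controls: for seq2seq pipelines max_length caps the decoder output, while causal pipelines use max_new_tokens (which excludes the prompt) and support return_full_text. Two small observations on the T5 branch: early_stopping only affects beam search, so with num_beams=1 it has no effect; and the hard-coded max_length=150 and temperature=0.7 silently override the function's max_tokens and temperature arguments. The same dispatch could be factored into a helper (hypothetical, names assumed from the diff; pad_token_id is left out since it needs the loaded tokenizer):

    # Hypothetical refactor of the inline branch above, not from the commit
    def generation_kwargs(model_type, max_tokens, temperature, top_p, repetition_penalty):
        if model_type == "t5":
            # seq2seq: max_length caps decoder output; no return_full_text
            return {"max_length": 150, "temperature": 0.7, "top_p": 0.9,
                    "do_sample": True, "num_beams": 1}
        # causal: max_new_tokens excludes the prompt
        return {"max_new_tokens": max_tokens, "temperature": temperature,
                "top_p": top_p, "repetition_penalty": repetition_penalty,
                "do_sample": True, "return_full_text": False}

    # usage: output = llm_client(user_prompt, **generation_kwargs(...))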
@@ -469,26 +491,60 @@ def generate_answer_langchain(
 # GRADIO INTERFACE
 # ============================================================================
 
-def fashion_chatbot(message: str, history: List[List[str]]) -> str:
+def fashion_chatbot(message: str, history: List[List[str]]):
     """
-    Chatbot function for Gradio interface
+    Chatbot function for Gradio interface with streaming
     """
     try:
         if not message or not message.strip():
-            return "Please ask a fashion-related question!"
+            yield "Please ask a fashion-related question!"
+            return
+
+        # Show typing indicator
+        yield "πŸ” Searching fashion knowledge base..."
 
-        # Generate answer using RAG pipeline
-        answer = generate_answer_langchain(
+        # Retrieve documents
+        retrieved_docs, confidence = retrieve_knowledge_langchain(
             message.strip(),
             vectorstore,
-            llm_client
+            top_k=CONFIG["top_k"]
         )
 
-        return answer
+        if not retrieved_docs:
+            yield "I couldn't find relevant information to answer your question."
+            return
+
+        # Update status
+        yield f"πŸ’­ Generating answer (found {len(retrieved_docs)} relevant sources)..."
+
+        # Generate answer with multiple attempts
+        llm_answer = None
+        for attempt in range(1, 5):
+            logger.info(f"\n πŸ€– LLM Generation Attempt {attempt}/4")
+            llm_answer = generate_llm_answer(message.strip(), retrieved_docs, llm_client, attempt)
+
+            if llm_answer:
+                break
+
+        # Fallback if needed
+        if not llm_answer:
+            logger.error(f" βœ— All LLM attempts failed - using fallback")
+            llm_answer = synthesize_direct_answer(message.strip(), retrieved_docs)
+
+        # Stream the answer word by word for natural flow
+        words = llm_answer.split()
+        displayed_text = ""
+
+        for i, word in enumerate(words):
+            displayed_text += word + " "
+
+            # Yield every 2-3 words for smooth streaming
+            if i % 2 == 0 or i == len(words) - 1:
+                yield displayed_text.strip()
 
     except Exception as e:
         logger.error(f"Error in chatbot: {e}")
-        return f"Sorry, I encountered an error: {str(e)}"
+        yield f"Sorry, I encountered an error: {str(e)}"
 
 # ============================================================================
 # INITIALIZE AND LAUNCH
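
Note: converting fashion_chatbot from a returning function into a generator is what enables streaming: gr.ChatInterface treats a generator fn as a streaming bot, and each yield replaces the currently displayed message (so the "Searching..." and "Generating..." status lines are overwritten by the first answer chunk). A self-contained toy version of the same pattern:

    import time
    import gradio as gr

    def echo_stream(message, history):
        shown = ""
        for word in f"You said: {message}".split():
            shown += word + " "
            time.sleep(0.05)     # simulate generation latency
            yield shown.strip()  # each yield replaces the shown message

    gr.ChatInterface(fn=echo_stream).launch()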
@@ -519,7 +575,7 @@ def startup():
 # Initialize on startup
 startup()
 
-# Create Gradio interface - simple version compatible with all Gradio versions
+# Create Gradio interface with streaming enabled
 demo = gr.ChatInterface(
     fn=fashion_chatbot,
     title="πŸ‘— Fashion Advisor - RAG System",
@@ -542,6 +598,10 @@ I can help with:
         "How to dress for a summer wedding?",
         "What's the best outfit for a university presentation?",
     ],
+    cache_examples=False,  # Don't cache for fresh responses
+    retry_btn="πŸ”„ Retry",
+    undo_btn="↩️ Undo",
+    clear_btn="πŸ—‘οΈ Clear",
 )
 
 # Launch
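
Note: a compatibility caveat on the new kwargs: retry_btn, undo_btn, and clear_btn are accepted by Gradio 3.x/4.x ChatInterface but were removed in Gradio 5 (where retry/undo/clear are built in), so on a Space with an unpinned Gradio version this call can raise a TypeError; cache_examples is fine in both. A guarded construction (a sketch, not part of the commit; fashion_chatbot is the function from the diff):

    import gradio as gr

    # Only pass the button-label kwargs on Gradio versions that still accept them
    extra = {}
    if int(gr.__version__.split(".")[0]) < 5:
        extra = {"retry_btn": "πŸ”„ Retry", "undo_btn": "↩️ Undo", "clear_btn": "πŸ—‘οΈ Clear"}

    demo = gr.ChatInterface(fn=fashion_chatbot, cache_examples=False, **extra)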
 