|
|
""" |
|
|
Fashion Advisor RAG - Hugging Face Deployment |
|
|
Complete RAG system with FAISS vector store and local LLM |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import logging |
|
|
import os |
|
|
from pathlib import Path |
|
|
from typing import List, Tuple, Dict, Optional |
|
|
import pickle |
|
|
|
|
|
|
|
|
import torch |
|
|
from transformers import pipeline |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import requests |
|
|
from langchain_community.vectorstores import FAISS |
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
|
from langchain.schema import Document |
|
|
|
|
|
|
|
|
# Module-wide logger; INFO level so startup progress is visible in the
# Hugging Face Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Global runtime configuration shared by every component of the pipeline.
CONFIG = {
    # Sentence-transformers model used to embed queries and documents.
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    # Filled in by initialize_llm() once a backend has been selected.
    "llm_model": None,
    # Directory containing index.faiss / index.pkl.
    "vector_store_path": ".",
    # Documents fetched per retrieval query variant.
    "top_k": 15,
    # Default sampling temperature for generation.
    "temperature": 0.75,
    # Default generation length budget (tokens).
    "max_tokens": 350,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Remote-inference configuration. USE_REMOTE_LLM is flipped to True below
# when an API key is found; the model can be overridden via env var.
USE_REMOTE_LLM = False
REMOTE_LLM_MODEL = os.environ.get("HF_INFERENCE_MODEL", "tiiuae/falcon-7b-instruct")

# API key resolution: environment variable first, then a local
# hf_token.txt file (developer convenience; must never be committed).
HF_INFERENCE_API_KEY = os.environ.get("HF_INFERENCE_API_KEY")
if not HF_INFERENCE_API_KEY:
    try:
        token_path = Path("hf_token.txt")
        if token_path.exists():
            HF_INFERENCE_API_KEY = token_path.read_text(encoding="utf-8").strip()
            logger.info("Loaded HF token from hf_token.txt (ensure this file is private and not committed)")
    except Exception:
        # Best-effort read: an unreadable token file just means local mode.
        logger.warning("Could not read hf_token.txt for HF token")

# A key from either source enables the remote backend.
if HF_INFERENCE_API_KEY:
    USE_REMOTE_LLM = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def initialize_llm():
    """Select and initialize the language-model backend.

    When a Hugging Face API key was found at import time, remote inference
    is used and no local model is loaded. Otherwise a local FLAN-T5
    text2text pipeline is created (GPU if available, else CPU).

    Returns:
        A transformers pipeline for local generation, or ``None`` when the
        remote Inference API is used.

    Raises:
        RuntimeError: if the local model fails to load.
    """
    global USE_REMOTE_LLM, REMOTE_LLM_MODEL

    if USE_REMOTE_LLM:
        logger.info(f"π Using remote Hugging Face Inference model: {REMOTE_LLM_MODEL}")
        CONFIG["llm_model"] = REMOTE_LLM_MODEL
        CONFIG["model_type"] = "remote"
        return None

    logger.info("π Initializing FREE local language model...")
    model_name = "google/flan-t5-large"

    try:
        logger.info(f"   Loading {model_name}...")
        # pipeline() convention: -1 means CPU, 0 is the first CUDA device.
        device = 0 if torch.cuda.is_available() else -1

        # Reduce peak RAM while the checkpoint is materialized.
        model_kwargs = {"low_cpu_mem_usage": True}

        llm_client = pipeline(
            "text2text-generation",
            model=model_name,
            device=device,
            model_kwargs=model_kwargs
        )

        CONFIG["llm_model"] = model_name
        CONFIG["model_type"] = "t5"
        logger.info(f"β LLM initialized: {model_name}")
        logger.info(f"   Device: {'GPU' if device == 0 else 'CPU'}")
        return llm_client

    except Exception as e:
        logger.error(f"β Failed to load model: {str(e)}")
        # Chain the original exception so the root cause is preserved.
        raise RuntimeError(f"Failed to initialize LLM: {str(e)}") from e
|
|
|
|
|
|
|
|
def remote_generate(prompt: str, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9) -> str:
    """Call the Hugging Face Inference API for remote generation.

    Requires the `HF_INFERENCE_API_KEY` env var (or token file) to be set
    and a model name in `REMOTE_LLM_MODEL`. Tries the router endpoint
    first and falls back to the legacy api-inference endpoint.

    Args:
        prompt: Full prompt text to send.
        max_new_tokens: Generation length budget.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.

    Returns:
        The generated text, or "" on any request/response failure.

    Raises:
        RuntimeError: if no API key is configured.
    """
    if not HF_INFERENCE_API_KEY:
        raise RuntimeError("HF_INFERENCE_API_KEY not set for remote generation")

    router_url = f"https://router.huggingface.co/models/{REMOTE_LLM_MODEL}"
    old_url = f"https://api-inference.huggingface.co/models/{REMOTE_LLM_MODEL}"
    headers = {"Authorization": f"Bearer {HF_INFERENCE_API_KEY}", "Accept": "application/json"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            # Only the continuation is wanted, not the echoed prompt.
            "return_full_text": False
        }
    }

    logger.info(f"   β Remote inference request to router {REMOTE_LLM_MODEL} (tokens={max_new_tokens}, temp={temperature})")
    try:
        r = requests.post(router_url, headers=headers, json=payload, timeout=120)
    except Exception as e:
        logger.error(f"   β Remote router request failed: {e}")

        # The router endpoint is newer; retry once against the legacy host.
        try:
            logger.info("   β Attempting legacy api-inference endpoint as fallback")
            r = requests.post(old_url, headers=headers, json=payload, timeout=120)
        except Exception as e2:
            logger.error(f"   β Legacy endpoint request failed: {e2}")
            return ""

    if r.status_code != 200:
        logger.error(f"   β Remote inference error {r.status_code}: {r.text[:200]}")
        return ""

    # A 200 with a non-JSON body (e.g. an HTML error page) must not crash.
    try:
        result = r.json()
    except ValueError:
        logger.error("   β Remote inference returned non-JSON response")
        return ""

    if isinstance(result, dict) and result.get("error"):
        logger.error(f"   β Remote inference returned error: {result.get('error')}")
        return ""

    # Typical success shape: [{"generated_text": "..."}]
    if isinstance(result, list) and result:
        first = result[0]
        if isinstance(first, dict):
            return first.get("generated_text", "").strip()
        return str(first).strip()

    if isinstance(result, dict) and "generated_text" in result:
        return result["generated_text"].strip()

    # Unknown shape: return its string form as a last resort.
    return str(result).strip()
|
|
|
|
|
def initialize_embeddings():
    """Create the CPU-based HuggingFace embeddings used for FAISS search.

    Embeddings are L2-normalized (``normalize_embeddings=True``) so the
    FAISS distance scores are on a comparable scale across queries.

    Returns:
        A configured ``HuggingFaceEmbeddings`` instance.
    """
    logger.info("π Initializing embeddings model...")

    embeddings = HuggingFaceEmbeddings(
        model_name=CONFIG["embedding_model"],
        # Embedding runs on CPU; the (optional) GPU is reserved for the LLM.
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True}
    )

    logger.info(f"β Embeddings initialized: {CONFIG['embedding_model']}")
    return embeddings
|
|
|
|
|
def load_vector_store(embeddings):
    """Load the FAISS vector store from ``CONFIG["vector_store_path"]``.

    Three strategies are tried in order:
      1. Plain ``FAISS.load_local()``.
      2. The same call after monkey-patching pydantic v1's
         ``BaseModel.__setstate__`` (works around pickles produced by an
         older pydantic that lack ``__fields_set__``).
      3. Last resort: scrape readable text fragments out of ``index.pkl``
         with a regex and rebuild a fresh index from those fragments.

    Args:
        embeddings: Embeddings object used to (re)load or rebuild the index.

    Returns:
        A usable FAISS vector store.

    Raises:
        FileNotFoundError: if index.faiss or index.pkl is missing.
        RuntimeError: if manual reconstruction cannot recover enough text.
    """
    logger.info("π Loading FAISS vector store...")

    vector_store_path = CONFIG["vector_store_path"]
    index_file = os.path.join(vector_store_path, "index.faiss")
    pkl_file = os.path.join(vector_store_path, "index.pkl")

    if not os.path.exists(index_file):
        raise FileNotFoundError(f"FAISS index file not found: {index_file}")

    if not os.path.exists(pkl_file):
        raise FileNotFoundError(f"FAISS metadata file not found: {pkl_file}")

    logger.info(f"β Found index.faiss ({os.path.getsize(index_file)/1024/1024:.2f} MB)")
    logger.info(f"β Found index.pkl ({os.path.getsize(pkl_file)/1024:.2f} KB)")

    try:
        # NOTE: allow_dangerous_deserialization unpickles index.pkl; this is
        # acceptable only because the index ships with this app (trusted).
        vectorstore = FAISS.load_local(
            vector_store_path,
            embeddings,
            allow_dangerous_deserialization=True
        )
        logger.info("β FAISS vector store loaded successfully")
        return vectorstore

    except Exception as e:
        logger.warning(f"β οΈ Pydantic compatibility issue: {str(e)[:100]}")
        logger.info("π Applying Pydantic monkey-patch and retrying...")

        # Strategy 2: patch pydantic v1 unpickling, then retry the load.
        try:
            import pydantic.v1.main as pydantic_main
            original_setstate = pydantic_main.BaseModel.__setstate__

            def patched_setstate(self, state):
                # Older pickles lack __fields_set__; reconstruct it from the
                # instance dict so BaseModel.__setstate__ succeeds.
                if '__fields_set__' not in state:
                    state['__fields_set__'] = set(state.get('__dict__', {}).keys())
                return original_setstate(self, state)

            pydantic_main.BaseModel.__setstate__ = patched_setstate
            logger.info("   β Pydantic monkey-patch applied")

        except Exception as patch_error:
            logger.warning(f"   β οΈ Pydantic patch failed: {patch_error}")

        try:
            vectorstore = FAISS.load_local(
                vector_store_path,
                embeddings,
                allow_dangerous_deserialization=True
            )
            logger.info("β FAISS vector store loaded with Pydantic patch")
            return vectorstore

        except Exception as e2:
            logger.error(f"   β Still failed after patch: {str(e2)[:100]}")
            logger.info("π Using manual reconstruction (last resort)...")

            import faiss
            import re

            # Sanity-check that the raw FAISS index is at least readable;
            # the rebuilt store below re-embeds from scratch instead of
            # reusing these vectors.
            index = faiss.read_index(index_file)
            logger.info("   β FAISS index loaded")

            with open(pkl_file, "rb") as f:
                raw_bytes = f.read()
            logger.info(f"   Read {len(raw_bytes)} bytes from pickle")

            # Heuristic: any run of >= 50 printable characters inside the
            # pickle is treated as a candidate document fragment.
            text_pattern = rb'([A-Za-z0-9\s\.\,\;\:\!\?\-\'\"\(\)]{50,})'
            matches = re.findall(text_pattern, raw_bytes)

            if len(matches) > 100:
                logger.info(f"   Found {len(matches)} potential document fragments")

                documents = []
                for idx, match in enumerate(matches[:5000]):
                    try:
                        content = match.decode('utf-8', errors='ignore').strip()
                        # Keep only substantial fragments.
                        if len(content) >= 100:
                            doc = Document(
                                page_content=content,
                                metadata={"source": "reconstructed", "id": idx}
                            )
                            documents.append(doc)
                    except Exception:
                        # Skip fragments that cannot be decoded/processed.
                        continue

                if len(documents) < 100:
                    raise RuntimeError(f"Only extracted {len(documents)} documents, need at least 100")

                logger.info(f"   β Extracted {len(documents)} high-quality documents")
                logger.info("   π Rebuilding FAISS index from scratch...")

                vectorstore = FAISS.from_documents(
                    documents=documents,
                    embedding=embeddings
                )

                logger.info(f"β FAISS vector store rebuilt from {len(documents)} documents")
                return vectorstore
            else:
                raise RuntimeError("Could not extract enough document content from pickle")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_extractive_answer(query: str, retrieved_docs: List[Document]) -> Optional[str]:
    """Build a long-form answer from retrieved documents using extractive
    selection + templated transitions. This avoids calling the LLM when it
    repeatedly fails or returns very short outputs.

    Args:
        query: The user's question.
        retrieved_docs: Retriever output; only ``page_content`` is used.

    Returns:
        A ~380-700 word answer, or None if no usable sentences were found.
    """
    logger.info(f"π§ Running extractive fallback for: '{query}'")

    import re

    # Pool all document text and split it into sentences; drop very short ones.
    all_text = "\n\n".join([d.page_content for d in retrieved_docs])

    sentences = re.split(r'(?<=[.!?])\s+', all_text)
    sentences = [s.strip() for s in sentences if len(s.strip()) > 30]

    if not sentences:
        logger.warning("   β No sentences found in retrieved documents for extractive fallback")
        return None

    # Score sentences by token overlap with the query plus a fixed fashion lexicon.
    query_tokens = set(re.findall(r"\w+", query.lower()))
    fashion_keywords = set(["outfit","wear","wardrobe","style","colors","color","layer","layering",
                            "blazer","trousers","dress","shirt","shoes","boots","sweater","jacket",
                            "care","wash","dry","clean","wool","cotton","silk","linen","fit","tailor",
                            "versatile","neutral","accessory","belt","bag","occasion","season","fall"])
    keywords = query_tokens.union(fashion_keywords)

    scored = []
    for s in sentences:
        s_tokens = set(re.findall(r"\w+", s.lower()))
        score = len(s_tokens & keywords)
        # Small bonus for longer sentences (capped at +3).
        score += min(3, len(s.split()) // 20)
        scored.append((score, s))

    scored.sort(key=lambda x: x[0], reverse=True)
    top_sentences = [s for _, s in scored[:60]]

    def pick(n, start=0):
        # Slice of the ranked sentences used to fill one answer section.
        return top_sentences[start:start+n]

    # Partition the ranked sentences into answer sections.
    intro = []
    intro.extend(pick(2, 0))
    key_items = pick(8, 2)
    styling = pick(8, 10)
    care = pick(6, 18)
    conclusion = pick(4, 24)

    template_intro = f"Here's a detailed answer to '{query}'. I'll cover essential wardrobe items, styling tips, and care advice so you can apply these suggestions practically."

    care_text = "\n\n".join(care)
    # Guarantee the care section always mentions washing guidance.
    if "dry clean" not in care_text.lower() and "hand wash" not in care_text.lower():
        care_text += "\n\nDry clean or hand wash in cold water with wool-specific detergent. Never wring out wool - gently squeeze excess water and lay flat to dry on a towel."

    parts = []
    parts.append(template_intro)
    if intro:
        parts.append(" ".join(intro))
    if key_items:
        parts.append("Key wardrobe items to prioritize:")
        parts.append(" ".join(key_items))
    if styling:
        parts.append("Practical styling tips:")
        parts.append(" ".join(styling))
    if care_text:
        parts.append("Care & maintenance:")
        parts.append(care_text)
    if conclusion:
        parts.append("Wrapping up:")
        parts.append(" ".join(conclusion))

    answer = "\n\n".join(parts)

    words = answer.split()
    word_count = len(words)

    # Pad short answers with generic styling advice to reach the target length.
    if word_count < 380:
        logger.info(f"   β Extractive answer short ({word_count} words). Appending templated paragraphs.")
        extra_paragraphs = []
        extra_paragraphs.append("A reliable strategy is to build around versatile, neutral pieces: a well-fitted blazer, tailored trousers, a versatile dress, and quality shoes. These items can be mixed and matched for many occasions.")
        extra_paragraphs.append("Focus on fit and fabric: ensure key items are well-tailored, prioritize breathable fabrics for comfort, and choose merino or wool blends for colder seasons to layer effectively.")
        extra_paragraphs.append("Layering is essential for transitional weather; combine a lightweight sweater under a jacket, and carry a scarf for added warmth and visual interest.")
        extra_paragraphs.append("Accessories like belts, a structured bag, and minimal jewelry can elevate basic outfits without extra effort. Neutral colors increase versatility and pair well with bolder accents.")
        answer += "\n\n" + "\n\n".join(extra_paragraphs)
        words = answer.split()
        word_count = len(words)

    # Hard cap very long answers.
    if word_count > 750:
        words = words[:700]
        answer = " ".join(words) + '...'
        word_count = 700

    logger.info(f"   β Extractive answer ready ({word_count} words)")
    return answer
|
|
|
|
|
|
|
|
def scaffold_and_polish(query: str, retrieved_docs: List[Document], llm_client) -> Optional[str]:
    """Create a concise scaffold (approx 150-220 words) from retrieved docs,
    then ask the remote (or local) LLM to expand and polish it into a
    320-420 word expert answer.

    Args:
        query: The user's question.
        retrieved_docs: Retriever output; the first 12 documents are used.
        llm_client: Local pipeline; ignored when remote inference is on.

    Returns:
        The polished answer (300-460 words, truncated if longer), or None
        if the scaffold or the polish step fails.
    """
    logger.info(f"π¨ Building scaffold for polish: '{query}'")
    import re

    # Build a sentence pool from the top documents.
    all_text = "\n\n".join([d.page_content for d in retrieved_docs[:12]])
    sentences = re.split(r'(?<=[.!?])\s+', all_text)
    sentences = [s.strip() for s in sentences if len(s.strip()) > 30]
    if not sentences:
        logger.warning("   β No sentences to build scaffold")
        return None

    # Keyword-overlap scoring (query terms + a fixed fashion lexicon).
    query_tokens = set(re.findall(r"\w+", query.lower()))
    fashion_keywords = set(["outfit","wear","wardrobe","style","colors","layer","blazer",
                            "trousers","dress","shoes","sweater","jacket","care","wool","fit",
                            "tailor","neutral","accessory","season","fall"])
    keywords = query_tokens.union(fashion_keywords)

    scored = []
    for s in sentences:
        s_tokens = set(re.findall(r"\w+", s.lower()))
        score = len(s_tokens & keywords)
        # Small bonus for longer sentences (capped at +2).
        score += min(2, len(s.split()) // 30)
        scored.append((score, s))

    # Take best sentences until the scaffold reaches roughly 180 words.
    scored.sort(key=lambda x: x[0], reverse=True)
    scaffold_parts = []
    word_count = 0
    for _, s in scored:
        scaffold_parts.append(s)
        word_count = len(" ".join(scaffold_parts).split())
        if word_count >= 180:
            break

    scaffold = "\n\n".join(scaffold_parts).strip()
    if not scaffold:
        logger.warning("   β Scaffold empty after selection")
        return None

    polish_prompt = f"""Please rewrite and expand the draft below into a clear, expert, natural-flowing answer of about 320-420 words to the question: {query}

Requirements:
- Keep paragraphs natural and connected.
- Preserve factual content from the draft and avoid inventing unsupported facts.
- Use a friendly, expert tone and provide practical, actionable advice.

Draft:
{scaffold}

Answer:
"""

    logger.info("   β Polishing scaffold with LLM")
    try:
        if USE_REMOTE_LLM:
            polished = remote_generate(polish_prompt, max_new_tokens=600, temperature=0.72, top_p=0.92)
        else:
            out = llm_client(polish_prompt, max_new_tokens=600, temperature=0.72, top_p=0.92, do_sample=True, num_beams=1)
            polished = out[0].get('generated_text', '') if isinstance(out, list) and out else str(out)
        polished = polished.strip()
    except Exception as e:
        logger.error(f"   β Polishing error: {e}")
        return None

    if not polished:
        logger.warning("   β Polished output empty")
        return None

    # Enforce the length window: reject short, truncate long.
    final_words = polished.split()
    fw = len(final_words)
    if fw < 300:
        logger.warning(f"   β Polished output too short ({fw} words)")
        return None
    if fw > 460:
        polished = ' '.join(final_words[:420]) + '...'

    logger.info(f"   β Polished answer ready ({len(polished.split())} words)")
    return polished
|
|
|
|
|
|
|
|
def retrieve_knowledge_langchain(
    query: str,
    vectorstore,
    top_k: int = 15
) -> Tuple[List[Document], float]:
    """Retrieve and deduplicate documents for a query.

    Searches with the raw query plus a fashion-flavored variant, converts
    FAISS distance scores to similarities, dedupes by a content-prefix key
    (keeping the higher-similarity copy), and sorts best-first.

    Args:
        query: The user's question.
        vectorstore: Loaded FAISS store.
        top_k: Documents fetched per query variant.

    Returns:
        Tuple of (documents sorted by similarity desc,
        confidence = mean similarity clamped to [0, 1]).
    """
    logger.info(f"π Retrieving knowledge for: '{query}'")

    query_variants = [
        query,
        f"fashion advice clothing outfit style for {query}",
    ]

    all_docs = []

    for variant in query_variants:
        try:
            docs_and_scores = vectorstore.similarity_search_with_score(variant, k=top_k)

            for doc, score in docs_and_scores:
                # FAISS returns a distance (lower = closer); map to (0, 1].
                similarity = 1.0 / (1.0 + score)
                doc.metadata['similarity'] = similarity
                doc.metadata['query_variant'] = variant
                all_docs.append(doc)

        except Exception as e:
            logger.error(f"Retrieval error for variant '{variant}': {e}")

    # Dedupe on the first 100 chars of content, keeping the better copy.
    unique_docs = {}
    for doc in all_docs:
        content_key = doc.page_content[:100]
        if content_key not in unique_docs:
            unique_docs[content_key] = doc
        else:
            if doc.metadata.get('similarity', 0) > unique_docs[content_key].metadata.get('similarity', 0):
                unique_docs[content_key] = doc

    final_docs = list(unique_docs.values())
    final_docs.sort(key=lambda x: x.metadata.get('similarity', 0), reverse=True)

    if final_docs:
        avg_similarity = sum(d.metadata.get('similarity', 0) for d in final_docs) / len(final_docs)
        confidence = min(avg_similarity, 1.0)
    else:
        confidence = 0.0

    logger.info(f"β Retrieved {len(final_docs)} unique documents (confidence: {confidence:.2f})")

    return final_docs, confidence
|
|
|
|
|
def generate_llm_answer(
    query: str,
    retrieved_docs: List[Document],
    llm_client,
    attempt: int = 1
) -> Optional[str]:
    """Generate a long-form (~320-420 word) answer with the LLM.

    Re-ranks the retrieved documents by lexical overlap, builds a grounded
    prompt, generates once, then issues "continue" prompts until the
    minimum word target is reached (or progress stalls).

    Args:
        query: The user's question.
        retrieved_docs: Candidate context documents, best first.
        llm_client: Local transformers pipeline, or None when remote
            inference is enabled.
        attempt: Retry counter; attempt > 1 uses hotter sampling.

    Returns:
        The answer text, or None when generation failed, the result stayed
        too short, or the answer opens with hedging/apology phrases.
    """
    if not llm_client and not USE_REMOTE_LLM:
        logger.error("   β LLM client not initialized and remote inference disabled")
        return None

    # --- Re-rank documents by lexical overlap with the query ---
    query_lower = query.lower()
    query_words = set(query_lower.split())

    scored_docs = []
    for doc in retrieved_docs[:20]:
        content = doc.page_content.lower()
        doc_words = set(content.split())
        overlap = len(query_words.intersection(doc_words))

        # Strong boost for documents flagged as verified.
        if doc.metadata.get('verified', False):
            overlap += 10

        # Mild boost for substantial documents.
        if len(doc.page_content) > 200:
            overlap += 3

        scored_docs.append((doc, overlap))

    scored_docs.sort(key=lambda x: x[1], reverse=True)
    top_docs = [doc[0] for doc in scored_docs[:8]]

    # --- Build the context block (each doc truncated to 400 chars) ---
    context_parts = []
    for doc in top_docs:
        content = doc.page_content.strip()
        if len(content) > 400:
            content = content[:400] + "..."
        context_parts.append(content)

    context_text = "\n\n".join(context_parts)

    # Length targets for the final answer and the continuation loop.
    target_min_words = 320
    target_max_words = 420
    chunk_target_words = 140
    max_iterations = 4

    def call_model(prompt, max_new_tokens, temperature, top_p, repetition_penalty):
        # Single dispatch point for remote vs local generation; returns ''
        # on any failure so callers can treat output uniformly.
        logger.info(f"   β Model call (temp={temperature}, max_new_tokens={max_new_tokens})")
        try:
            if USE_REMOTE_LLM:
                return remote_generate(prompt, max_new_tokens, temperature, top_p)

            out = llm_client(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                num_beams=1,
                repetition_penalty=repetition_penalty,
                early_stopping=False
            )
            if isinstance(out, list) and out:
                return out[0].get('generated_text', '') if isinstance(out[0], dict) else str(out[0])
            return str(out)
        except Exception as e:
            logger.error(f"   β Model call error: {e}")
            return ''

    base_prompt = f"""Please write a clear, natural-flowing, well-structured fashion answer using the context below.

Question: {query}

Context (use where helpful):
{context_text[:1200]}

Requirements:
- Aim for a long-form answer ~{target_min_words}-{target_max_words} words, structured in paragraphs.
- Use the provided context where relevant and add practical, actionable advice.
- Keep a friendly, expert tone and avoid hedging phrases like "I can't" or "I don't know".

Answer:
"""

    # Retry attempts sample hotter and allow a longer first shot.
    if attempt == 1:
        temperature = 0.70
        max_new_tokens = 600
        top_p = 0.92
        repetition_penalty = 1.1
    else:
        temperature = 0.82
        max_new_tokens = 800
        top_p = 0.95
        repetition_penalty = 1.15

    initial_output = call_model(base_prompt, max_new_tokens, temperature, top_p, repetition_penalty)
    response = (initial_output or '').strip()

    if not response:
        logger.warning("   β Empty initial response")
        response = ''

    words = response.split()
    word_count = len(words)

    # Single-shot success: truncate if over the cap and return immediately.
    if word_count >= target_min_words:
        if word_count > target_max_words:
            response = ' '.join(words[:target_max_words]) + '...'
            word_count = target_max_words
        logger.info(f"   β Single-shot generated {word_count} words")
        return response

    # --- Iteratively extend a too-short answer ---
    accumulated = response
    prev_word_count = word_count

    for i in range(max_iterations):
        remaining = max(0, target_min_words - len(accumulated.split()))
        if remaining <= 0:
            break

        continue_prompt = f"""Continue the previous answer in the same tone and style to add about {min(chunk_target_words, remaining)} words.

Do not repeat sentences already present. Keep paragraphs natural and connected.

Previous answer:
{accumulated}

Continue:
"""

        cont_output = call_model(continue_prompt, max_new_tokens=450, temperature=0.78, top_p=0.93, repetition_penalty=1.08)
        cont_text = (cont_output or '').strip()

        if not cont_text:
            logger.warning(f"   β Continuation {i+1} returned empty β stopping")
            break

        # Repetition guard: skip continuations already present in the text.
        if cont_text in accumulated or accumulated.endswith(cont_text[:50]):
            logger.warning(f"   β Continuation {i+1} appears repetitive β stopping")
            break

        accumulated = accumulated.rstrip() + '\n\n' + cont_text

        current_word_count = len(accumulated.split())
        logger.info(f"   β After continuation {i+1}, words={current_word_count}")

        if current_word_count >= target_min_words:
            break

        # Stall guard: no new words means further rounds are pointless.
        if current_word_count == prev_word_count:
            logger.warning("   β No progress from continuation β stopping")
            break
        prev_word_count = current_word_count

    final_words = accumulated.split()
    final_count = len(final_words)

    if final_count < target_min_words:
        logger.warning(f"   β Final answer too short ({final_count} words) after continuations")
        return None

    if final_count > target_max_words:
        logger.info(f"   β οΈ Final answer long ({final_count} words). Truncating to {target_max_words} words.")
        accumulated = ' '.join(final_words[:target_max_words]) + '...'
        final_count = target_max_words

    # Reject answers that open with apology/hedging boilerplate.
    apology_phrases = ["i cannot", "i can't", "i'm sorry", "i apologize", "i don't have"]
    if any(phrase in accumulated.lower()[:200] for phrase in apology_phrases):
        logger.warning("   β Apology/hedging detected in final answer")
        return None

    logger.info(f"   β Built long-form answer ({final_count} words)")
    return accumulated
|
|
|
|
|
def generate_answer_langchain(
    query: str,
    vectorstore,
    llm_client
) -> str:
    """End-to-end RAG pipeline: retrieve, generate with retries, fall back.

    Fallback order when both LLM attempts fail:
      1. ``scaffold_and_polish()`` (remote LLM only)
      2. ``generate_extractive_answer()``
      3. A static apology message.

    Args:
        query: The user's question.
        vectorstore: Loaded FAISS store.
        llm_client: Local pipeline or None (remote mode).

    Returns:
        An answer string; never raises to the caller.
    """
    logger.info(f"\n{'='*80}")
    logger.info(f"Processing query: '{query}'")
    logger.info(f"{'='*80}")

    retrieved_docs, confidence = retrieve_knowledge_langchain(
        query,
        vectorstore,
        top_k=CONFIG["top_k"]
    )

    if not retrieved_docs:
        return "I couldn't find relevant information to answer your question."

    # Up to two LLM attempts (the second uses hotter sampling).
    llm_answer = None
    for attempt in range(1, 3):
        logger.info(f"\n   π€ LLM Generation Attempt {attempt}/2")
        llm_answer = generate_llm_answer(query, retrieved_docs, llm_client, attempt)

        if llm_answer:
            logger.info("   β LLM answer generated successfully")
            break
        else:
            logger.warning(f"   β Attempt {attempt}/2 failed, retrying...")

    if not llm_answer:
        logger.error("   β All 2 LLM attempts failed")

        # Fallback 1: scaffold-and-polish (only meaningful with remote LLM).
        if USE_REMOTE_LLM:
            try:
                logger.info("   β Attempting scaffold-and-polish using remote LLM")
                polished = scaffold_and_polish(query, retrieved_docs, llm_client)
                if polished:
                    logger.info("   β Scaffold-and-polish produced an answer")
                    return polished
            except Exception as e:
                logger.error(f"   β Scaffold-and-polish error: {e}")

        # Fallback 2: purely extractive answer (no LLM involved).
        try:
            logger.info("   β Using extractive fallback generator")
            fallback = generate_extractive_answer(query, retrieved_docs)
            if fallback:
                logger.info("   β Extractive fallback produced an answer")
                return fallback
        except Exception as e:
            logger.error(f"   β Extractive fallback error: {e}")

        # Fallback 3: static apology.
        return "I apologize, but I'm having trouble generating a response. Please try rephrasing your question or ask something else."

    return llm_answer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fashion_chatbot(message: str, history: List[List[str]]):
    """Gradio chat handler: streams status updates, then the answer.

    Yields intermediate status strings while retrieving and generating,
    then streams the final answer a few words at a time. ``history`` is
    required by gr.ChatInterface but is not used.
    """
    import time

    try:
        # Reject empty/whitespace-only input up front.
        if not message or not message.strip():
            yield "Please ask a fashion-related question!"
            return

        question = message.strip()
        yield "π Searching fashion knowledge..."

        docs, _confidence = retrieve_knowledge_langchain(
            question,
            vectorstore,
            top_k=CONFIG["top_k"]
        )

        if not docs:
            yield "I couldn't find relevant information to answer your question."
            return

        yield f"π Generating answer ({len(docs)} sources found)..."

        # Two generation attempts; stop at the first success.
        answer = None
        for attempt_no in (1, 2):
            logger.info(f"\n   π€ LLM Generation Attempt {attempt_no}/2")
            answer = generate_llm_answer(question, docs, llm_client, attempt_no)
            if answer:
                break

        if not answer:
            logger.error("   β All LLM attempts failed")
            yield "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
            return

        # Stream the answer: emit the accumulated text every third word
        # (and on the final word) with a short pause for a typing effect.
        tokens = answer.split()
        last_index = len(tokens) - 1
        shown = ""
        for idx, token in enumerate(tokens):
            shown += token + " "
            if idx % 3 == 0 or idx == last_index:
                yield shown.strip()
                time.sleep(0.05)

    except Exception as e:
        logger.error(f"Error in chatbot: {e}")
        yield f"Sorry, I encountered an error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared components populated by startup(); all None until initialized.
llm_client = None   # transformers pipeline, or None in remote-inference mode
embeddings = None   # HuggingFaceEmbeddings instance
vectorstore = None  # FAISS vector store
|
|
|
|
|
def startup():
    """Initialize embeddings, vector store, and LLM into module globals.

    Runs once at import time so the Gradio handler can use the shared
    ``llm_client``, ``embeddings``, and ``vectorstore`` objects.

    Raises:
        FileNotFoundError: if the FAISS index files are missing.
        RuntimeError: if the local LLM fails to load.
    """
    global llm_client, embeddings, vectorstore

    logger.info("π Starting Fashion Advisor RAG...")

    # Order matters: the vector store needs the embeddings object.
    embeddings = initialize_embeddings()
    vectorstore = load_vector_store(embeddings)
    llm_client = initialize_llm()

    logger.info("β All components initialized successfully!")


# Eagerly initialize at import time (Hugging Face Spaces imports this
# module and then serves `demo`, so everything must be ready first).
startup()
|
|
|
|
|
# Gradio chat UI. `fashion_chatbot` is a generator, so intermediate status
# messages and the word-by-word answer stream to the client.
demo = gr.ChatInterface(
    fn=fashion_chatbot,
    title="π Fashion Advisor - RAG System",
    description="""
    **Ask me anything about fashion!** π

    I can help with:
    - Outfit recommendations for occasions
    - Color combinations and styling
    - Seasonal fashion advice
    - Body type and fit guidance
    - Wardrobe essentials

    *Powered by RAG with FAISS vector search and local LLM*
    """,
    examples=[
        "What should I wear to a business meeting?",
        "What colors go well with navy blue?",
        "What are essential wardrobe items for fall?",
        "How to dress for a summer wedding?",
        "What's the best outfit for a university presentation?",
    ],
)
|
|
|
|
|
if __name__ == "__main__":
    # Local development entry point; Hugging Face Spaces serves `demo`
    # directly without executing this branch.
    demo.launch()
|
|
|