"""
Fashion Advisor RAG - Hugging Face Deployment

Complete RAG system with FAISS vector store and local LLM.
"""

import asyncio
import logging
import os
import sys
import warnings
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import requests
import torch
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

os.environ['TRANSFORMERS_VERBOSITY'] = 'error'

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

torch.set_num_threads(4)
torch.set_grad_enabled(False)

warnings.filterwarnings("ignore", message="MatMul8bitLt")
warnings.filterwarnings("ignore", message="torch_dtype")
warnings.filterwarnings("ignore", message="Invalid file descriptor")
warnings.filterwarnings("ignore", message="generation flags")
warnings.filterwarnings("ignore", category=UserWarning)

if sys.platform == 'linux':
    try:
        asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
    except Exception:
        pass

CONFIG = {
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    "llm_model": None,
    "vector_store_path": ".",
    "top_k": 12,
    "temperature": 0.75,
    "max_tokens": 600,
}

LOCAL_LLM_MODEL = os.environ.get("LOCAL_LLM_MODEL", "distilgpt2")
# remote_generate() references REMOTE_LLM_MODEL, which was otherwise undefined.
# The Qwen default below is an assumption (the docstring mentions Qwen2.5);
# override it via the environment variable.
REMOTE_LLM_MODEL = os.environ.get("REMOTE_LLM_MODEL", "Qwen/Qwen2.5-7B-Instruct")
USE_8BIT_QUANTIZATION = False
USE_REMOTE_LLM = False

MAX_CONTEXT_LENGTH = 400
USE_CACHING = True
ENABLE_FAST_MODE = False

HF_INFERENCE_API_KEY = os.environ.get("HF_INFERENCE_API_KEY")
if not HF_INFERENCE_API_KEY:
    try:
        token_path = Path("hf_token.txt")
        if token_path.exists():
            HF_INFERENCE_API_KEY = token_path.read_text(encoding="utf-8").strip()
            logger.info("Loaded HF token from hf_token.txt (ensure this file is private and not committed)")
    except Exception:
        logger.warning("Could not read hf_token.txt for HF token")

if HF_INFERENCE_API_KEY:
    USE_REMOTE_LLM = True


def initialize_llm():
    """Initialize DistilGPT2 for local CPU generation.

    DistilGPT2 is lightweight (82M params), fast, and has no special dependencies.
    """
    global LOCAL_LLM_MODEL

    logger.info(f"🤖 Initializing DistilGPT2: {LOCAL_LLM_MODEL}")
    logger.info("   Lightweight and fast on CPU")

    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"   Device: {device}")

        logger.info("   Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(LOCAL_LLM_MODEL)

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info("   Tokenizer ready")

        logger.info("   Loading DistilGPT2 (5-10 seconds)...")
        model = AutoModelForCausalLM.from_pretrained(
            LOCAL_LLM_MODEL,
            torch_dtype=torch.float32
        )

        model = model.to(device)
        model.eval()
        logger.info("   Model ready")

        logger.info("   Creating generation pipeline...")
        llm_client = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if device == "cuda" else -1,
            max_new_tokens=100
        )

        CONFIG["llm_model"] = LOCAL_LLM_MODEL
        CONFIG["model_type"] = "distilgpt2_local"

        logger.info(f"✅ DistilGPT2 initialized: {LOCAL_LLM_MODEL}")
        logger.info("   Size: 82M parameters (very lightweight)")
        logger.info("   Speed: 2-5 seconds per response")

        return llm_client

    except ImportError as ie:
        logger.error(f"❌ Missing required library: {ie}")
        logger.info("   Install with: pip install transformers torch")
        raise
    except Exception as e:
        logger.error(f"❌ Failed to load LLM: {str(e)}")
        logger.info("   This may be due to insufficient memory")
        import traceback
        logger.error(traceback.format_exc())
        raise Exception(f"Failed to initialize LLM: {str(e)}")


def remote_generate(prompt: str, max_new_tokens: int = 200, temperature: float = 0.7, top_p: float = 0.9) -> str:
    """Call Hugging Face Inference API - fast and reliable.

    Uses the Qwen2.5 model configured in REMOTE_LLM_MODEL.
    """
    if not HF_INFERENCE_API_KEY:
        raise Exception("HF_INFERENCE_API_KEY not set for remote generation")

    api_url = f"https://api-inference.huggingface.co/models/{REMOTE_LLM_MODEL}"
    headers = {"Authorization": f"Bearer {HF_INFERENCE_API_KEY}"}

    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "return_full_text": False
        }
    }

    logger.info(f"   → Remote inference (tokens={max_new_tokens})")
    try:
        r = requests.post(api_url, headers=headers, json=payload, timeout=90)
    except Exception as e:
        logger.error(f"   ❌ Remote request failed: {e}")
        return ""

    if r.status_code == 503:
        logger.warning("   ⚠️ Model loading (503), retrying in 5s...")
        import time
        time.sleep(5)
        try:
            r = requests.post(api_url, headers=headers, json=payload, timeout=90)
        except Exception as e:
            logger.error(f"   ❌ Retry failed: {e}")
            return ""

    if r.status_code != 200:
        logger.error(f"   ❌ Remote inference error {r.status_code}: {r.text[:300]}")
        return ""

    result = r.json()

    if isinstance(result, dict) and result.get("error"):
        logger.error(f"   ❌ Remote inference returned error: {result.get('error')}")
        return ""

    generated_text = ""

    if isinstance(result, list) and result:
        first = result[0]
        if isinstance(first, dict):
            generated_text = first.get("generated_text", "")
        else:
            generated_text = str(first)
    elif isinstance(result, dict):
        generated_text = result.get("generated_text", str(result))
    else:
        generated_text = str(result)

    generated_text = generated_text.strip()
    if prompt in generated_text:
        generated_text = generated_text.replace(prompt, "").strip()

    logger.info(f"   ✅ Generated {len(generated_text.split())} words remotely")
    return generated_text
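

# Illustrative sketch (assumed wiring, not called by the chat flow): prefer the
# remote API when a key is configured and fall back to the local pipeline
# otherwise. The helper name `generate_text` is hypothetical.
def generate_text(prompt: str, llm_client=None, max_new_tokens: int = 200) -> str:
    if USE_REMOTE_LLM:
        text = remote_generate(prompt, max_new_tokens=max_new_tokens)
        if text:
            return text
    if llm_client is not None:
        out = llm_client(prompt, max_new_tokens=max_new_tokens, return_full_text=False)
        if isinstance(out, list) and out:
            return out[0].get("generated_text", "").strip()
    return ""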


def initialize_embeddings():
    logger.info("🔄 Initializing embeddings model...")

    embeddings = HuggingFaceEmbeddings(
        model_name=CONFIG["embedding_model"],
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True}
    )

    logger.info(f"✅ Embeddings initialized: {CONFIG['embedding_model']}")
    return embeddings


def load_vector_store(embeddings):
    logger.info("📂 Loading FAISS vector store...")

    vector_store_path = CONFIG["vector_store_path"]
    index_file = os.path.join(vector_store_path, "index.faiss")
    pkl_file = os.path.join(vector_store_path, "index.pkl")

    if not os.path.exists(index_file):
        raise FileNotFoundError(f"FAISS index file not found: {index_file}")

    if not os.path.exists(pkl_file):
        raise FileNotFoundError(f"FAISS metadata file not found: {pkl_file}")

    logger.info(f"✅ Found index.faiss ({os.path.getsize(index_file)/1024/1024:.2f} MB)")
    logger.info(f"✅ Found index.pkl ({os.path.getsize(pkl_file)/1024:.2f} KB)")

    try:
        vectorstore = FAISS.load_local(
            vector_store_path,
            embeddings,
            allow_dangerous_deserialization=True
        )
        logger.info("✅ FAISS vector store loaded successfully")
        return vectorstore

    except Exception as e:
        logger.warning(f"⚠️ Pydantic compatibility issue: {str(e)[:100]}")
        logger.info("🔄 Applying Pydantic monkey-patch and retrying...")

        try:
            import pydantic.v1.main as pydantic_main
            original_setstate = pydantic_main.BaseModel.__setstate__

            def patched_setstate(self, state):
                if '__fields_set__' not in state:
                    state['__fields_set__'] = set(state.get('__dict__', {}).keys())
                return original_setstate(self, state)

            pydantic_main.BaseModel.__setstate__ = patched_setstate
            logger.info("   ✅ Pydantic monkey-patch applied")

        except Exception as patch_error:
            logger.warning(f"   ⚠️ Pydantic patch failed: {patch_error}")

        try:
            vectorstore = FAISS.load_local(
                vector_store_path,
                embeddings,
                allow_dangerous_deserialization=True
            )
            logger.info("✅ FAISS vector store loaded with Pydantic patch")
            return vectorstore

        except Exception as e2:
            logger.error(f"   ❌ Still failed after patch: {str(e2)[:100]}")
            logger.info("🔧 Using manual reconstruction (last resort)...")

            import faiss

            index = faiss.read_index(index_file)
            logger.info("   ✅ FAISS index loaded")

            with open(pkl_file, "rb") as f:
                import re
                raw_bytes = f.read()
                logger.info(f"   Read {len(raw_bytes)} bytes from pickle")

                text_pattern = rb'([A-Za-z0-9\s\.\,\;\:\!\?\-\'\"\(\)]{50,})'
                matches = re.findall(text_pattern, raw_bytes)

                if len(matches) > 100:
                    logger.info(f"   Found {len(matches)} potential document fragments")

                    documents = []
                    for idx, match in enumerate(matches[:5000]):
                        try:
                            content = match.decode('utf-8', errors='ignore').strip()
                            if len(content) >= 100:
                                doc = Document(
                                    page_content=content,
                                    metadata={"source": "reconstructed", "id": idx}
                                )
                                documents.append(doc)
                        except Exception:
                            continue

                    if len(documents) < 100:
                        raise Exception(f"Only extracted {len(documents)} documents, need at least 100")

                    logger.info(f"   ✅ Extracted {len(documents)} high-quality documents")
                    logger.info("   🔄 Rebuilding FAISS index from scratch...")

                    vectorstore = FAISS.from_documents(
                        documents=documents,
                        embedding=embeddings
                    )

                    logger.info(f"✅ FAISS vector store rebuilt from {len(documents)} documents")
                    return vectorstore
                else:
                    raise Exception("Could not extract enough document content from pickle")


def generate_extractive_answer(query: str, retrieved_docs: List[Document]) -> Optional[str]:
    """Build a long-form answer from retrieved documents using extractive
    selection + templated transitions. This avoids calling the LLM when it
    repeatedly fails or returns very short outputs.
    """
    logger.info(f"🔧 Running extractive fallback for: '{query}'")

    import re

    all_text = "\n\n".join([d.page_content for d in retrieved_docs])

    sentences = re.split(r'(?<=[.!?])\s+', all_text)
    sentences = [s.strip() for s in sentences if len(s.strip()) > 30]

    if not sentences:
        logger.warning("   ❌ No sentences found in retrieved documents for extractive fallback")
        return None

    query_tokens = set(re.findall(r"\w+", query.lower()))
    fashion_keywords = set(["outfit", "wear", "wardrobe", "style", "colors", "color", "layer", "layering",
                            "blazer", "trousers", "dress", "shirt", "shoes", "boots", "sweater", "jacket",
                            "care", "wash", "dry", "clean", "wool", "cotton", "silk", "linen", "fit", "tailor",
                            "versatile", "neutral", "accessory", "belt", "bag", "occasion", "season", "fall"])
    keywords = query_tokens.union(fashion_keywords)

    scored = []
    for s in sentences:
        s_tokens = set(re.findall(r"\w+", s.lower()))
        score = len(s_tokens & keywords)
        score += min(3, len(s.split()) // 20)
        scored.append((score, s))

    scored.sort(key=lambda x: x[0], reverse=True)
    top_sentences = [s for _, s in scored[:60]]

    def pick(n, start=0):
        return top_sentences[start:start + n]

    intro = []
    intro.extend(pick(2, 0))
    key_items = pick(8, 2)
    styling = pick(8, 10)
    care = pick(6, 18)
    conclusion = pick(4, 24)

    template_intro = f"Here's a detailed answer to '{query}'. I'll cover essential wardrobe items, styling tips, and care advice so you can apply these suggestions practically."

    care_text = "\n\n".join(care)
    if "dry clean" not in care_text.lower() and "hand wash" not in care_text.lower():
        care_text += "\n\nDry clean or hand wash in cold water with wool-specific detergent. Never wring out wool - gently squeeze excess water and lay flat to dry on a towel."

    parts = []
    parts.append(template_intro)
    if intro:
        parts.append(" ".join(intro))
    if key_items:
        parts.append("Key wardrobe items to prioritize:")
        parts.append(" ".join(key_items))
    if styling:
        parts.append("Practical styling tips:")
        parts.append(" ".join(styling))
    if care_text:
        parts.append("Care & maintenance:")
        parts.append(care_text)
    if conclusion:
        parts.append("Wrapping up:")
        parts.append(" ".join(conclusion))

    answer = "\n\n".join(parts)

    words = answer.split()
    word_count = len(words)

    logger.info(f"   ✅ Extractive answer ready ({word_count} words)")
    return answer
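

# NOTE: generate_extractive_answer() and scaffold_and_polish() below are kept
# as optional degradation paths; neither is currently wired into the
# fashion_chatbot() flow, which calls generate_llm_answer() directly.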


def scaffold_and_polish(query: str, retrieved_docs: List[Document], llm_client) -> Optional[str]:
    """Create a concise scaffold (approx 150-220 words) from retrieved docs,
    then ask the remote (or local) LLM to expand and polish it into a
    320-420 word expert answer. Returns None if polishing fails.
    """
    logger.info(f"🎨 Building scaffold for polish: '{query}'")
    import re

    all_text = "\n\n".join([d.page_content for d in retrieved_docs[:12]])
    sentences = re.split(r'(?<=[.!?])\s+', all_text)
    sentences = [s.strip() for s in sentences if len(s.strip()) > 30]
    if not sentences:
        logger.warning("   ❌ No sentences to build scaffold")
        return None

    query_tokens = set(re.findall(r"\w+", query.lower()))
    fashion_keywords = set(["outfit", "wear", "wardrobe", "style", "colors", "layer", "blazer",
                            "trousers", "dress", "shoes", "sweater", "jacket", "care", "wool", "fit",
                            "tailor", "neutral", "accessory", "season", "fall"])
    keywords = query_tokens.union(fashion_keywords)

    scored = []
    for s in sentences:
        s_tokens = set(re.findall(r"\w+", s.lower()))
        score = len(s_tokens & keywords)
        score += min(2, len(s.split()) // 30)
        scored.append((score, s))

    scored.sort(key=lambda x: x[0], reverse=True)
    scaffold_parts = []
    word_count = 0
    for _, s in scored:
        scaffold_parts.append(s)
        word_count = len(" ".join(scaffold_parts).split())
        if word_count >= 180:
            break

    scaffold = "\n\n".join(scaffold_parts).strip()
    if not scaffold:
        logger.warning("   ❌ Scaffold empty after selection")
        return None

    polish_prompt = f"""Expand this draft into a complete, detailed fashion answer for: {query}

Draft: {scaffold}

Write a comprehensive, natural answer with practical advice and specific recommendations.

Enhanced answer:
"""

    logger.info("   → Polishing scaffold with the LLM")
    try:
        out = llm_client(
            polish_prompt,
            max_new_tokens=600,
            temperature=0.75,
            top_p=0.92,
            do_sample=True,
            repetition_penalty=1.1,
            pad_token_id=llm_client.tokenizer.eos_token_id
        )

        if isinstance(out, list) and out:
            polished = out[0].get('generated_text', '') if isinstance(out[0], dict) else str(out[0])
        else:
            polished = str(out)

        if polish_prompt in polished:
            polished = polished[len(polish_prompt):].strip()
        else:
            polished = polished.strip()

    except Exception as e:
        logger.error(f"   ❌ Polishing error: {e}")
        return None

    if not polished:
        logger.warning("   ❌ Polished output empty")
        return None

    final_words = polished.split()
    fw = len(final_words)

    if fw < 50:
        logger.warning(f"   ❌ Polished output too short ({fw} words)")
        return None

    logger.info(f"   ✅ Polished answer ready ({fw} words)")
    return polished


def retrieve_knowledge_langchain(
    query: str,
    vectorstore,
    top_k: int = 12
) -> Tuple[List[Document], float]:
    logger.info(f"🔍 Retrieving knowledge for: '{query}'")

    query_variants = [
        query,
        f"fashion advice clothing outfit style for {query}",
    ]

    all_docs = []

    for variant in query_variants:
        try:
            docs_and_scores = vectorstore.similarity_search_with_score(variant, k=top_k)

            for doc, score in docs_and_scores:
                similarity = 1.0 / (1.0 + score)
                doc.metadata['similarity'] = similarity
                doc.metadata['query_variant'] = variant
                all_docs.append(doc)

        except Exception as e:
            logger.error(f"Retrieval error for variant '{variant}': {e}")

    unique_docs = {}
    for doc in all_docs:
        content_key = doc.page_content[:100]
        if content_key not in unique_docs:
            unique_docs[content_key] = doc
        else:
            if doc.metadata.get('similarity', 0) > unique_docs[content_key].metadata.get('similarity', 0):
                unique_docs[content_key] = doc

    final_docs = list(unique_docs.values())
    final_docs.sort(key=lambda x: x.metadata.get('similarity', 0), reverse=True)

    if final_docs:
        avg_similarity = sum(d.metadata.get('similarity', 0) for d in final_docs) / len(final_docs)
        confidence = min(avg_similarity, 1.0)
    else:
        confidence = 0.0

    logger.info(f"✅ Retrieved {len(final_docs)} unique documents (confidence: {confidence:.2f})")

    return final_docs, confidence
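

# Note on scoring: similarity_search_with_score() returns raw FAISS distances
# (lower is better; L2 by default in LangChain's FAISS wrapper), so the
# 1 / (1 + distance) mapping above folds them into a (0, 1] similarity where
# higher is better, which is what the confidence average expects.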


def generate_llm_answer(
    query: str,
    retrieved_docs: List[Document],
    llm_client,
    attempt: int = 1
) -> Optional[str]:

    if not llm_client:
        logger.error("   ❌ LLM client not initialized")
        return None

    query_lower = query.lower()
    query_words = set(query_lower.split())

    scored_docs = []
    for doc in retrieved_docs[:20]:
        content = doc.page_content.lower()
        doc_words = set(content.split())
        overlap = len(query_words.intersection(doc_words))

        if doc.metadata.get('verified', False):
            overlap += 10

        if len(doc.page_content) > 200:
            overlap += 3

        scored_docs.append((doc, overlap))

    scored_docs.sort(key=lambda x: x[1], reverse=True)
    top_docs = [doc for doc, _ in scored_docs[:8]]

    context_parts = []
    for doc in top_docs[:3]:
        content = doc.page_content.strip()
        if len(content) > 200:
            content = content[:200] + "..."
        context_parts.append(content)

    context_text = "\n\n".join(context_parts)

    target_min_words = 100
    target_max_words = 999999
    chunk_target_words = 0
    max_iterations = 0
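    # NOTE: with max_iterations = 0 (and chunk_target_words = 0) the
    # continuation loop further down never runs; the answer is taken from the
    # single initial generation. Raise these to re-enable iterative lengthening.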

    def call_model(prompt, max_new_tokens, temperature):
        """Generate with DistilGPT2."""
        try:
            logger.info(f"   → Generating (max_tokens={max_new_tokens})")

            out = llm_client(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                return_full_text=False,
                repetition_penalty=1.3,
                no_repeat_ngram_size=2,
                top_k=40,
                top_p=0.9,
                pad_token_id=llm_client.tokenizer.eos_token_id,
                eos_token_id=llm_client.tokenizer.eos_token_id
            )

            if not out or not isinstance(out, list) or len(out) == 0:
                return ''

            generated = out[0].get('generated_text', '').strip()

            import re

            generated = re.sub(r'\b[A-Z]:\s*(?=[A-Z]:)', '', generated)
            generated = re.sub(r'^[A-Z]:\s*', '', generated)
            generated = generated.strip()

            word_count = len(generated.split())
            logger.info(f"   ✅ Generated {word_count} words")
            return generated

        except Exception as e:
            logger.error(f"   ❌ Error: {e}")
            return ''

    base_prompt = f"""For the question "{query}", here is helpful fashion advice:

{context_text[:300]}

To summarize:"""

    if attempt == 1:
        max_new_tokens = 120
        temperature = 0.6
    else:
        max_new_tokens = 150
        temperature = 0.65

    logger.info(f"   → Starting generation with prompt: {base_prompt[:200]}...")
    initial_output = call_model(base_prompt, max_new_tokens, temperature)
    response = (initial_output or '').strip()

    if not response:
        logger.warning("   ❌ Empty initial response - model may not be generating")
        logger.warning(f"   ❌ Prompt was: {base_prompt[:300]}")
        response = ''

    words = response.split()
    word_count = len(words)

    logger.info(f"   → Initial response: {word_count} words")

    if word_count >= target_min_words:
        logger.info(f"   ✅ Generated {word_count} words naturally")
        return response

    if word_count >= 50:
        logger.info(f"   ✅ Accepted natural response ({word_count} words)")
        return response

    if word_count >= 20:
        logger.info(f"   ⚠️ Short but acceptable response ({word_count} words)")
        return response

    if word_count >= 10:
        logger.info(f"   ⚠️ Very short response ({word_count} words) but accepting")
        return response

    if word_count >= 5:
        logger.info(f"   ⚠️ EMERGENCY: Accepting tiny response ({word_count} words)")
        return response

    accumulated = response
    prev_word_count = word_count

    for i in range(max_iterations):
        remaining = max(0, target_min_words - len(accumulated.split()))
        if remaining <= 0:
            break

        continue_prompt = f"""Add {min(chunk_target_words, remaining)} more words to complete this answer:

{accumulated[-400:]}

Continue naturally:
"""

        cont_output = call_model(continue_prompt, max_new_tokens=250, temperature=0.80)
        cont_text = (cont_output or '').strip()

        if not cont_text:
            logger.warning(f"   ❌ Continuation {i+1} returned empty - stopping")
            break

        if cont_text in accumulated or accumulated.endswith(cont_text[:50]):
            logger.warning(f"   ❌ Continuation {i+1} appears repetitive - stopping")
            break

        accumulated = accumulated.rstrip() + '\n\n' + cont_text

        current_word_count = len(accumulated.split())
        logger.info(f"   → After continuation {i+1}, words={current_word_count}")

        if current_word_count >= target_min_words:
            break

        if current_word_count == prev_word_count:
            logger.warning("   ❌ No progress from continuation - stopping")
            break
        prev_word_count = current_word_count

    final_words = accumulated.split()
    final_count = len(final_words)

    if final_count < target_min_words:
        logger.warning(f"   ❌ Final answer too short ({final_count} words) after continuations")
        return None

    if final_count > target_max_words:
        logger.info(f"   ⚠️ Final answer long ({final_count} words). Truncating to {target_max_words} words.")
        accumulated = ' '.join(final_words[:target_max_words]) + '...'
        final_count = target_max_words

    apology_phrases = ["i cannot", "i can't", "i'm sorry", "i apologize", "i don't have"]
    if any(phrase in accumulated.lower()[:200] for phrase in apology_phrases):
        logger.warning("   ❌ Apology/hedging detected in final answer")
        return None

    logger.info(f"   ✅ Built long-form answer ({final_count} words)")
    return accumulated


def generate_answer_langchain(
    query: str,
    vectorstore,
    llm_client
) -> str:
    logger.info(f"\n{'='*80}")
    logger.info(f"Processing query: '{query}'")
    logger.info(f"{'='*80}")

    retrieved_docs, confidence = retrieve_knowledge_langchain(
        query,
        vectorstore,
        top_k=CONFIG["top_k"]
    )

    if not retrieved_docs:
        return "I couldn't find relevant information to answer your question."

    max_attempts = 2

    llm_answer = None
    for attempt in range(1, max_attempts + 1):
        logger.info(f"\n   🤖 LLM Generation Attempt {attempt}/{max_attempts}")
        llm_answer = generate_llm_answer(query, retrieved_docs, llm_client, attempt)

        if llm_answer:
            logger.info("   ✅ LLM answer generated successfully")
            return llm_answer
        else:
            if attempt < max_attempts:
                logger.warning(f"   ❌ Attempt {attempt}/{max_attempts} failed, retrying...")

    logger.error(f"   ❌ All {max_attempts} LLM attempts failed")
    return "I apologize, but I'm having trouble generating a response. Please try rephrasing your question or ask something else."


def fashion_chatbot(message: str, history: List[List[str]]):
    try:
        if not message or not message.strip():
            yield "Please ask a fashion-related question!"
            return

        yield "🔍 Searching fashion knowledge..."

        retrieved_docs, confidence = retrieve_knowledge_langchain(
            message.strip(),
            vectorstore,
            top_k=CONFIG["top_k"]
        )

        if not retrieved_docs:
            yield "I couldn't find relevant information to answer your question."
            return

        yield f"📝 Generating answer ({len(retrieved_docs)} sources found)..."

        llm_answer = None
        for attempt in range(1, 3):
            logger.info(f"\n   🤖 LLM Generation Attempt {attempt}/2")
            llm_answer = generate_llm_answer(message.strip(), retrieved_docs, llm_client, attempt)

            if llm_answer:
                break

        if not llm_answer:
            logger.error("   ❌ All LLM attempts failed")
            yield "I apologize, but I'm having trouble generating a response. Please try rephrasing your question."
            return

        import time
        words = llm_answer.split()
        displayed_text = ""

        for i, word in enumerate(words):
            displayed_text += word + " "

            if i % 5 == 0 or i == len(words) - 1:
                yield displayed_text.strip()
                time.sleep(0.02)

    except Exception as e:
        logger.error(f"Error in chatbot: {e}")
        yield f"Sorry, I encountered an error: {str(e)}"


llm_client = None
embeddings = None
vectorstore = None


def startup():
    global llm_client, embeddings, vectorstore

    logger.info("🚀 Starting Fashion Advisor RAG...")

    embeddings = initialize_embeddings()
    vectorstore = load_vector_store(embeddings)
    llm_client = initialize_llm()

    logger.info("✅ All components initialized successfully!")


startup()


demo = gr.ChatInterface(
    fn=fashion_chatbot,
    title="👗 Fashion Advisor - RAG System",
    description="""
**Ask me anything about fashion!** 👠

I can help with:
- Outfit recommendations for occasions
- Color combinations and styling
- Seasonal fashion advice
- Body type and fit guidance
- Wardrobe essentials

*Powered by RAG with FAISS vector search and local LLM*
""",
    examples=[
        "What should I wear to a business meeting?",
        "What colors go well with navy blue?",
        "What are essential wardrobe items for fall?",
        "How to dress for a summer wedding?",
        "What's the best outfit for a university presentation?",
    ],
)

if __name__ == "__main__":
    demo.launch()