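"""RAG pipeline: company routing, document filtering, two-stage retrieval
(dense embedding + cross-encoder reranking), and multimodal answer generation."""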
import torch
import torch.nn.functional as F
from PIL import Image
from typing import List, Dict, Any, Callable, Tuple

from rag.prompting import build_messages
from rag.config import Settings
from rag.logging_utils import get_logger
from rag.utils import last_token_pool
from qwen_vl_utils import process_vision_info

logger = get_logger(__name__)


def _route_companies(
    query: str,
    router_model,
    settings: Settings,
) -> Tuple[List[str], str | None]:
    """Detect which supported companies a query targets.

    Returns (detected_companies, blocked_message); blocked_message is None
    when routing succeeded.
    """
    allowed_companies = {
        "apple": "Apple",
        "aapl": "Apple",
        "microsoft": "Microsoft",
        "msft": "Microsoft",
    }
    try:
        labels = list(settings.router_labels)
        entities = router_model.predict_entities(query, labels, threshold=settings.router_threshold)
    except Exception as e:
        logger.warning("GLiNER failed (%s); falling back to keyword matching.", e)
        entities = []

    detected_targets = []
    unsupported_targets = []
    for entity in entities:
        text = entity.get("text", "")
        name_clean = text.lower().strip().strip(".,?!'s")
        found_match = False
        for key, canonical_name in allowed_companies.items():
            if key in name_clean:
                detected_targets.append(canonical_name)
                found_match = True
                break
        if not found_match:
            unsupported_targets.append(text)

    # Fallback: if NER surfaced nothing usable, scan the raw query for keywords.
    if not detected_targets:
        query_lower = query.lower()
        for key, canonical_name in allowed_companies.items():
            if key in query_lower:
                detected_targets.append(canonical_name)

    detected_targets = list(dict.fromkeys(detected_targets))  # dedupe, keep order
    if detected_targets:
        return detected_targets, None
    if unsupported_targets:
        unique_unsupported = list(set(unsupported_targets))
        return [], (
            f"⛔ **Out of Scope:** I detected a request for **{', '.join(unique_unsupported)}**. "
            "This system only has access to **Microsoft** and **Apple** data."
        )
    # Case C: nothing was found at all.
    return [], (
        "❓ **Ambiguous Query:** I could not identify a specific company (Apple or Microsoft). "
        "Please name the company explicitly in your question."
    )
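
# Example behavior (illustrative): "How did MSFT revenue trend?" routes to
# (["Microsoft"], None) via the keyword fallback, while "Compare Tesla and Nvidia"
# returns ([], <out-of-scope message>) assuming GLiNER tags both names as entities.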


def _filter_docs(
    dataset: List[Dict[str, Any]],
    detected_companies: List[str],
) -> List[Dict[str, Any]]:
    """Keep non-empty docs whose doc_name mentions one of the detected companies."""
    valid_docs = []
    for i, doc in enumerate(dataset):
        doc_name = doc.get("doc_name", "Doc")
        if detected_companies:
            if not any(company in doc_name for company in detected_companies):
                continue
        text = (doc.get("text") or "").strip()
        if text:
            valid_docs.append({"text": text, "original_index": i, "doc_name": doc_name})
    return valid_docs
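
# Example (illustrative): with detected_companies == ["Apple"], a doc named
# "Apple_10K_2023_p12" passes the substring filter; docs with empty text are
# dropped even when the name matches.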


def _prepare_images(
    dataset: List[Dict[str, Any]],
    valid_docs: List[Dict[str, Any]],
    top_k_indices: List[int],
    r_scores,
    top_k_indices_local: List[int],
):
    """Load the reranked pages as images plus gallery and markdown metadata."""
    images_content = []
    gallery_preview = []
    meta_info = ""
    for idx_local in top_k_indices_local:
        # Map the reranker's local rank back to valid_docs, then to the dataset.
        idx_in_valid = top_k_indices[idx_local]
        final_doc_idx = valid_docs[idx_in_valid]["original_index"]
        doc = dataset[final_doc_idx]
        image_path = doc["image_path"]
        score = r_scores[idx_local].item()
        doc_name = doc.get("doc_name", "Unknown")
        try:
            img = Image.open(image_path)
            header_text = f"SOURCE DOCUMENT: {doc_name} (Confidence: {score:.2f})\n"
            images_content.append({"type": "text", "text": header_text})
            images_content.append({"type": "image", "image": img})
            gallery_preview.append((img, doc_name))
            meta_info += f"- **{doc_name}** (Score: {score:.2f})\n"
        except Exception as e:
            logger.warning("Failed to open image %s: %s", image_path, e)
            continue
    return images_content, gallery_preview, meta_info
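
# Worked example of the index mapping above (illustrative): if valid_docs came
# from dataset indices [2, 5, 7] and embedding retrieval chose top_k_indices
# == [2, 0], the reranker scores those two pairs; idx_local == 1 then maps to
# top_k_indices[1] == 0, i.e. valid_docs[0], i.e. dataset index 2.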


def make_retrieve_and_answer(
    dataset: List[Dict[str, Any]],
    models,
    settings: Settings | None = None,
) -> Callable[[str], tuple]:
    """Build the end-to-end query function: route -> retrieve -> rerank -> generate."""
    if settings is None:
        settings = models.settings if hasattr(models, "settings") else Settings()

    # Deferred import: `spaces` is specific to Hugging Face Spaces (ZeroGPU).
    import spaces

    @spaces.GPU
    def retrieve_and_answer(query: str):
        logger.info("User question: %s", query)
        if not dataset:
            return [], "Empty corpus", "No documents loaded."

        detected_companies, blocked_msg = _route_companies(query, models.router_model, settings)
        if blocked_msg is not None:
            return [], "", blocked_msg
        logger.info("Router detected companies: %s", detected_companies)

        valid_docs = _filter_docs(dataset, detected_companies)
        if not valid_docs:
            return [], "", "System Error: Valid targets detected but no matching documents found."

        query_text = (
            "Instruct: Given a user query, retrieve relevant passages that answer the query.\n"
            f"Query: {query}"
        )
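
        # Stage 1: dense retrieval. Query and documents are embedded with
        # last-token pooling and L2-normalized, so the dot product below is
        # cosine similarity.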
        with torch.no_grad():
            q_inputs = models.embed_tokenizer(
                [query_text],
                max_length=settings.embed_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(models.embed_model.device)
            q_outputs = models.embed_model(**q_inputs)
            q_emb = last_token_pool(q_outputs.last_hidden_state, q_inputs["attention_mask"])
            q_emb = F.normalize(q_emb, p=2, dim=1)

            # Encode documents one at a time to keep peak GPU memory bounded.
            d_embeddings_list = []
            doc_texts = [d["text"] for d in valid_docs]
            for text in doc_texts:
                d_inputs = models.embed_tokenizer(
                    [text],
                    max_length=settings.embed_max_length,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                ).to(models.embed_model.device)
                d_outputs = models.embed_model(**d_inputs)
                batch_emb = last_token_pool(d_outputs.last_hidden_state, d_inputs["attention_mask"])
                batch_emb = F.normalize(batch_emb, p=2, dim=1)
                d_embeddings_list.append(batch_emb)
            d_emb_final = torch.cat(d_embeddings_list, dim=0)
            scores = (q_emb @ d_emb_final.T).squeeze(0)

        k_val = min(settings.embed_top_k, len(scores))
        top_k_indices = torch.topk(scores, k=k_val).indices.tolist()
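
        # Stage 2: cross-encoder reranking. Each (query, passage) pair is scored
        # jointly, which is more accurate than embedding similarity alone.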
        pairs = [[query, valid_docs[idx]["text"]] for idx in top_k_indices]
        with torch.no_grad():
            r_inputs = models.rerank_tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=settings.rerank_max_length,
            ).to(models.rerank_model.device)
            r_scores = models.rerank_model(**r_inputs, return_dict=True).logits.view(-1).float()
        k_rerank = min(settings.rerank_top_k, len(r_scores))
        top_k_indices_local = torch.topk(r_scores, k=k_rerank).indices.tolist()

        meta_info = f"**AI Router Focus:** {', '.join(detected_companies)}\n\n"
        images_content, gallery_preview, meta_sources = _prepare_images(
            dataset, valid_docs, top_k_indices, r_scores, top_k_indices_local
        )
        meta_info += meta_sources
        if not images_content:
            return [], "", "No images found for the retrieved passages."
        messages = build_messages(query, images_content)
        text_input = models.gen_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, _video_inputs = process_vision_info(messages)
        inputs = models.gen_processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        ).to(models.gen_model.device)
        generated_ids = models.gen_model.generate(**inputs, max_new_tokens=settings.max_new_tokens)
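        # `generate` returns prompt + completion; slice off the prompt tokens so
        # only the newly generated answer is decoded.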
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = models.gen_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return gallery_preview, meta_info, response

    return retrieve_and_answer
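
# Minimal wiring sketch (hypothetical: `rag.models.load_models` and
# `rag.data.load_dataset` are assumed helpers, not part of this module):
#
#     from rag.models import load_models
#     from rag.data import load_dataset
#
#     models = load_models()    # bundles embed/rerank/router/gen models
#     dataset = load_dataset()  # list of {"text", "doc_name", "image_path"} dicts
#     ask = make_retrieve_and_answer(dataset, models)
#     gallery, meta, answer = ask("What was Apple's total revenue in FY2023?")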