# src/rag_pipeline.py — RAG pipeline: retrieval (ChromaDB) + generation (Qwen2.5-VL-3B).
from typing import List, Dict, Optional, Tuple
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoTokenizer
from qwen_vl_utils import process_vision_info
from PIL import Image
import io
class TokenChunker:
    """Count tokens and split/truncate text to fit model context limits."""

    def __init__(self, model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"):
        """Load the tokenizer used for token counting.

        Args:
            model_name: HuggingFace model id whose tokenizer to use.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Qwen2.5-VL supports 131,072 tokens; stay conservative at 100K.
        self.max_tokens = 100000

    def count_tokens(self, text: str) -> int:
        """Return the token count of *text*.

        Falls back to a characters/4 heuristic if the tokenizer fails,
        so callers always get a usable estimate instead of an exception.
        """
        try:
            return len(self.tokenizer.encode(text, add_special_tokens=False))
        except Exception as e:
            print(f"Error counting tokens: {e}")
            # Rough estimate: 1 token ≈ 4 characters for English/Russian.
            return len(text) // 4

    def chunk_text(self, text: str, chunk_size: int = 50000) -> List[str]:
        """Split *text* into chunks of at most *chunk_size* characters.

        Splits on blank-line paragraph boundaries where possible. A single
        paragraph longer than *chunk_size* is hard-split into fixed-size
        slices, so no returned chunk can exceed the limit (previously an
        oversized paragraph was returned as one oversized chunk).
        """
        if len(text) <= chunk_size:
            return [text]
        chunks: List[str] = []
        current = ""
        for paragraph in text.split("\n\n"):
            if len(paragraph) >= chunk_size:
                # Flush pending text, then hard-split the oversized paragraph.
                if current:
                    chunks.append(current.strip())
                    current = ""
                chunks.extend(
                    paragraph[i:i + chunk_size]
                    for i in range(0, len(paragraph), chunk_size)
                )
            elif len(current) + len(paragraph) < chunk_size:
                current += paragraph + "\n\n"
            else:
                if current:
                    chunks.append(current.strip())
                current = paragraph + "\n\n"
        if current:
            chunks.append(current.strip())
        return chunks

    def truncate_to_token_limit(self, text: str, token_limit: int = 50000) -> str:
        """Truncate *text* so its estimated token count fits *token_limit*."""
        current_tokens = self.count_tokens(text)
        if current_tokens <= token_limit:
            return text
        print(f"Text too long ({current_tokens} tokens). Truncating to {token_limit}...")
        # Scale by the observed characters-per-token ratio, with a 10% safety margin.
        char_per_token = len(text) / current_tokens
        target_chars = int(token_limit * char_per_token * 0.9)
        return text[:target_chars]
class Qwen25VLInferencer:
    """Handle inference with the Qwen2.5-VL-3B vision-language model.

    The model is loaded WITHOUT device_map="auto" and moved to the target
    device explicitly afterwards, which avoids the "meta tensor" error some
    transformers versions raise, with a CPU/float32 fallback when loading
    on the preferred device fails. Context length is managed via
    TokenChunker so prompts stay inside the model's 131K-token window.
    """

    def __init__(self, model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct", device: str = "cuda"):
        """Load model, processor and token chunker with device fallbacks.

        Args:
            model_name: HuggingFace model id to load.
            device: Preferred device ("cuda" or "cpu"); downgraded to "cpu"
                when CUDA is unavailable or GPU loading fails.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        print(f"Loading Qwen2.5-VL-3B model on device: {self.device}")
        try:
            # Load without device_map first, then move explicitly:
            # device_map="auto" can leave weights as meta tensors.
            # Half precision on GPU; full precision on CPU.
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            print(f"Using dtype: {dtype}")
            print("Loading model weights...")
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=dtype,
                trust_remote_code=True,
            )
            print(f"Moving model to {self.device}...")
            self.model = self.model.to(self.device)
            self.model.eval()  # inference only
            print("✅ Model loaded successfully")
        except RuntimeError as e:
            if "meta tensor" in str(e):
                # Known failure mode: retry entirely on CPU in full precision.
                print(f"⚠️ Meta tensor error detected: {e}")
                print("Falling back to CPU mode...")
                self._load_cpu_fallback(model_name)
                print("✅ Model loaded on CPU")
            else:
                raise
        except Exception as e:
            # Any other load failure: best-effort CPU retry (may still raise).
            print(f"❌ Error loading model: {e}")
            print("Trying fallback CPU loading...")
            self._load_cpu_fallback(model_name)
        # Processor handles text tokenization and image preprocessing.
        print("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        self.token_chunker = TokenChunker(model_name)
        print("✅ Model initialization complete")

    def _load_cpu_fallback(self, model_name: str) -> None:
        """Load the model on CPU in float32 (shared fallback path)."""
        self.device = "cpu"
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )
        self.model = self.model.to("cpu")
        self.model.eval()

    def _prepare_text_message(self, text: str) -> List[Dict]:
        """Return a text-only content list for the model."""
        return [{"type": "text", "text": text}]

    def _prepare_image_text_message(self, image_path: str, text: str) -> List[Dict]:
        """Return an image-plus-text content list for the model."""
        return [
            {"type": "image", "image": image_path},
            {"type": "text", "text": text}
        ]

    def _build_context(self, retrieved_docs: List[Dict]) -> str:
        """Format retrieved documents (with relevance scores) into one context string."""
        context = "КОНТЕКСТ ИЗ ДОКУМЕНТОВ:\n"
        for doc in retrieved_docs:
            relevance = doc.get('relevance_score', 0)
            context += f"\n[Релевантность: {relevance:.2f}]\n{doc['document']}\n"
        return context

    def generate_answer(
        self,
        query: str,
        retrieved_docs: List[Dict],
        retrieved_images: Optional[List[str]] = None,
        max_new_tokens: int = 128
    ) -> str:
        """Generate an answer for *query* grounded in *retrieved_docs*.

        Args:
            query: User question (answered in Russian per the system prompt).
            retrieved_docs: Dicts with at least a 'document' key and an
                optional 'relevance_score'.
            retrieved_images: Optional image paths; only the first is used.
            max_new_tokens: Generation budget (hard-capped at 512).

        Returns:
            The decoded model response, or an error string on failure
            (this method never raises; each stage is wrapped).
        """
        # Build the document context and keep it well inside the context window.
        context = self._build_context(retrieved_docs)
        context = self.token_chunker.truncate_to_token_limit(context, token_limit=50000)
        system_prompt = "Ты помощник для анализа документов. Используй предоставленный контекст для ответа на вопросы. Отвечай на русском языке. Будь кратким и точным."
        full_query = f"{system_prompt}\n\n{context}\n\nВопрос: {query}\n\nОтвет:"
        query_tokens = self.token_chunker.count_tokens(full_query)
        print(f"Query token count: {query_tokens}")
        if query_tokens > 100000:
            # Still too long: keep only the top 3 documents with a tighter cap.
            print(f"Query exceeds token limit. Reducing context...")
            context = self._build_context(retrieved_docs[:3])
            context = self.token_chunker.truncate_to_token_limit(context, token_limit=30000)
            full_query = f"{system_prompt}\n\n{context}\n\nВопрос: {query}\n\nОтвет:"
        messages = self._prepare_text_message(full_query)
        if retrieved_images and len(retrieved_images) > 0:
            try:
                image_message = self._prepare_image_text_message(
                    retrieved_images[0],
                    f"Проанализируй это изображение в контексте вопроса: {query}"
                )
                # NOTE(review): this is a flat content list, not the role/content
                # chat format qwen_vl_utils usually expects — confirm upstream;
                # failures are caught below and the run degrades to text-only.
                messages = image_message + [{"type": "text", "text": full_query}]
            except Exception as e:
                print(f"Warning: Could not include images: {e}")
        # Extract image/video tensors only when an image item is present.
        image_inputs = []
        video_inputs = []
        try:
            if any(msg.get('type') == 'image' for msg in messages):
                image_inputs, video_inputs = process_vision_info(messages)
        except Exception as e:
            print(f"Warning: Could not process images: {e}")
        try:
            inputs = self.processor(
                text=[full_query],
                images=image_inputs if image_inputs else None,
                videos=video_inputs if video_inputs else None,
                padding=True,
                return_tensors='pt',
            )
        except Exception as e:
            print(f"Error preparing inputs: {e}")
            return f"Error preparing inputs: {e}"
        if self.device == "cuda":
            inputs = inputs.to("cuda")
        try:
            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=min(max_new_tokens, 512),  # hard cap
                    num_beams=1,       # greedy decoding
                    do_sample=False    # deterministic output
                )
        except Exception as e:
            print(f"Error during generation: {e}")
            return f"Error generating response: {e}"
        try:
            # Strip the prompt tokens so only newly generated text is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            response = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            return response[0] if response else "Could not generate response"
        except Exception as e:
            print(f"Error decoding response: {e}")
            return f"Error decoding response: {e}"

    def summarize_document(
        self,
        document_text: str,
        max_new_tokens: int = 512
    ) -> str:
        """Summarize *document_text* in Russian, truncating it to fit context.

        Args:
            document_text: Raw document text; truncated to ~40K tokens.
            max_new_tokens: Generation budget (hard-capped at 512).

        Returns:
            The summary, or an error string on failure.
        """
        document_text = self.token_chunker.truncate_to_token_limit(
            document_text,
            token_limit=40000
        )
        prompt = f"""Пожалуйста, создай подробное резюме следующего документа на русском языке.
Документ:
{document_text}
Резюме:"""
        try:
            inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors='pt',
            )
            if self.device == "cuda":
                inputs = inputs.to("cuda")
            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=min(max_new_tokens, 512),
                    num_beams=1,
                    do_sample=False
                )
            # Decode only the tokens generated after the prompt.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            response = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            return response[0] if response else "Could not generate summary"
        except Exception as e:
            print(f"Error generating summary: {e}")
            return f"Error: {e}"
class RAGPipeline:
    """Complete RAG pipeline: ChromaDB retrieval feeding Qwen2.5-VL generation."""

    def __init__(self, chroma_manager, device: str = "cuda"):
        """Wire the retriever (*chroma_manager*) to a fresh Qwen2.5-VL inferencer."""
        self.chroma_manager = chroma_manager
        self.inferencer = Qwen25VLInferencer(device=device)

    def answer_question(
        self,
        query: str,
        n_retrieved: int = 5,
        max_new_tokens: int = 512
    ) -> Dict:
        """Answer a user question via retrieve-then-generate.

        Retrieves the top *n_retrieved* documents for *query*, then asks the
        model to answer grounded in them. Returns a result dict; on failure
        the 'answer' field carries the error text instead of raising.
        """
        docs = self.chroma_manager.search(query, n_results=n_retrieved)
        # Guard clause: nothing retrieved means nothing to ground an answer in.
        if not docs:
            return {
                "answer": "Не найдены релевантные документы для ответа на вопрос.",
                "retrieved_docs": [],
                "query": query,
                "error": "No documents found"
            }
        try:
            answer = self.inferencer.generate_answer(
                query=query,
                retrieved_docs=docs,
                retrieved_images=[],  # image retrieval not wired up yet
                max_new_tokens=max_new_tokens
            )
        except Exception as e:
            answer = f"Error generating answer: {e}"
        return {
            "answer": answer,
            "retrieved_docs": docs,
            "query": query,
            "model": "Qwen2.5-VL-3B",
            "doc_count": len(docs)
        }

    def summarize_all_documents(self, max_chars: int = 100000) -> str:
        """Summarize the indexed corpus (up to 10 documents, *max_chars* total).

        Returns the model's summary, or an explanatory error string.
        """
        info = self.chroma_manager.get_collection_info()
        if info['document_count'] == 0:
            return "No documents in database to summarize."
        try:
            stored = self.chroma_manager.collection.get(include=['documents'])
            documents = stored['documents']
            if not documents:
                return "Could not retrieve documents for summarization."
            # Concatenate at most 10 documents, stopping before max_chars.
            combined = ""
            for doc in documents[:10]:
                if len(combined) + len(doc) >= max_chars:
                    break
                combined += doc + "\n\n"
            if not combined:
                # First document alone exceeds the budget: take its prefix.
                combined = documents[0][:max_chars]
            return self.inferencer.summarize_document(combined)
        except Exception as e:
            return f"Error summarizing documents: {e}"