|
|
from typing import List, Dict, Optional, Tuple |
|
|
import torch |
|
|
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoTokenizer |
|
|
from qwen_vl_utils import process_vision_info |
|
|
from PIL import Image |
|
|
import io |
|
|
|
|
|
|
|
|
class TokenChunker:
    """Handle token counting and chunking for model context limits."""

    def __init__(self, model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"):
        """Initialize tokenizer for token counting.

        Args:
            model_name: Hugging Face model id whose tokenizer defines
                the token counts used by this chunker.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Hard ceiling on tokens ever sent to the model in one request.
        self.max_tokens = 100000

    def count_tokens(self, text: str) -> int:
        """Count tokens in ``text`` with the model tokenizer.

        Falls back to a rough chars/4 estimate if tokenization fails,
        so callers always get a usable number.
        """
        try:
            return len(self.tokenizer.encode(text, add_special_tokens=False))
        except Exception as e:
            print(f"Error counting tokens: {e}")
            # ~4 characters per token is a common rough average.
            return len(text) // 4

    def chunk_text(self, text: str, chunk_size: int = 50000) -> List[str]:
        """Split text into chunks of at most ``chunk_size`` characters.

        Splits on paragraph boundaries ("\\n\\n") where possible.
        FIX: a single paragraph longer than ``chunk_size`` is now
        hard-split, so no returned chunk exceeds the limit (previously
        such a paragraph came back as one oversized chunk).  The size
        check also accounts for the two-character separator appended
        between paragraphs.
        """
        if len(text) <= chunk_size:
            return [text]

        chunks: List[str] = []
        current_chunk = ""

        for paragraph in text.split("\n\n"):
            # Hard-split paragraphs that alone exceed the chunk size.
            while len(paragraph) > chunk_size:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                    current_chunk = ""
                chunks.append(paragraph[:chunk_size])
                paragraph = paragraph[chunk_size:]

            # +2 accounts for the "\n\n" separator appended below.
            if len(current_chunk) + len(paragraph) + 2 <= chunk_size:
                current_chunk += paragraph + "\n\n"
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = paragraph + "\n\n"

        # Flush the remainder, skipping separator-only residue.
        if current_chunk.strip():
            chunks.append(current_chunk.strip())

        return chunks

    def truncate_to_token_limit(self, text: str, token_limit: int = 50000) -> str:
        """Truncate ``text`` so it fits within ``token_limit`` tokens.

        Uses the observed chars-per-token ratio with a 10% safety
        margin instead of re-tokenizing iteratively — fast but
        approximate.
        """
        current_tokens = self.count_tokens(text)

        if current_tokens <= token_limit:
            return text

        print(f"Text too long ({current_tokens} tokens). Truncating to {token_limit}...")

        # Estimate how many characters fit; 0.9 keeps a safety margin
        # because the ratio is only an average.
        char_per_token = len(text) / current_tokens
        target_chars = int(token_limit * char_per_token * 0.9)

        return text[:target_chars]
|
|
|
|
|
|
|
|
# FIX: the class was declared twice in a row with an identical docstring;
# the first (empty) declaration was dead code, immediately shadowed by the
# second. Collapsed to a single definition.
class Qwen25VLInferencer:
    """Handle inference with Qwen2.5-VL-3B model - FIXED meta tensor issue."""
|
|
|
|
|
    def __init__(self, model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct", device: str = "cuda"):
        """Initialize Qwen2.5-VL model with proper device handling.

        Loads model weights, the multimodal processor, and a TokenChunker.
        Device resolution: the requested ``device`` is used only when CUDA
        is actually available, otherwise we silently fall back to CPU.

        Args:
            model_name: Hugging Face model id to load.
            device: Preferred device ("cuda" or "cpu").
        """
        # Honor the request only if CUDA is really present.
        self.device = device if torch.cuda.is_available() else "cpu"
        print(f"Loading Qwen2.5-VL-3B model on device: {self.device}")

        try:
            # fp16 saves GPU memory; CPU math needs full fp32.
            if self.device == "cuda":
                dtype = torch.float16
            else:
                dtype = torch.float32

            print(f"Using dtype: {dtype}")

            print("Loading model weights...")
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=dtype,
                trust_remote_code=True,
            )

            # Loaded on CPU by default (no device_map used), so move
            # explicitly; this is where meta-tensor errors can surface.
            print(f"Moving model to {self.device}...")
            if self.device == "cuda":
                self.model = self.model.to("cuda")
            else:
                self.model = self.model.to("cpu")

            # Inference only — disable dropout etc.
            self.model.eval()

            print("✅ Model loaded successfully")

        except RuntimeError as e:
            # Specific recovery path for the "meta tensor" failure mode:
            # reload from scratch in fp32 on CPU.
            if "meta tensor" in str(e):
                print(f"⚠️ Meta tensor error detected: {e}")
                print("Falling back to CPU mode...")
                self.device = "cpu"

                self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch.float32,
                    trust_remote_code=True,
                )
                self.model = self.model.to("cpu")
                self.model.eval()
                print("✅ Model loaded on CPU")
            else:
                # Any other RuntimeError is unexpected — propagate.
                raise

        except Exception as e:
            # Last-resort fallback: retry on CPU in fp32. If this load
            # also fails, the exception propagates to the caller.
            print(f"❌ Error loading model: {e}")
            print("Trying fallback CPU loading...")

            self.device = "cpu"
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            )
            self.model = self.model.to("cpu")
            self.model.eval()

        # Processor handles both tokenization and image preprocessing.
        print("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True
        )

        # Shared utility for token counting / context truncation.
        self.token_chunker = TokenChunker(model_name)

        print("✅ Model initialization complete")
|
|
|
|
|
def _prepare_text_message(self, text: str) -> List[Dict]: |
|
|
"""Prepare text-only message for the model.""" |
|
|
return [{"type": "text", "text": text}] |
|
|
|
|
|
def _prepare_image_text_message(self, image_path: str, text: str) -> List[Dict]: |
|
|
"""Prepare message with image and text.""" |
|
|
return [ |
|
|
{"type": "image", "image": image_path}, |
|
|
{"type": "text", "text": text} |
|
|
] |
|
|
|
|
|
    def generate_answer(
        self,
        query: str,
        retrieved_docs: List[Dict],
        retrieved_images: Optional[List[str]] = None,
        max_new_tokens: int = 128
    ) -> str:
        """
        Generate answer based on query and retrieved documents.
        FIXED: Includes token chunking and context length management

        Args:
            query: User question (expected in Russian, given the prompts).
            retrieved_docs: Dicts with a 'document' key and an optional
                'relevance_score' key (defaults to 0 when missing).
            retrieved_images: Optional image paths; only the first one is
                considered.
            max_new_tokens: Generation budget, capped at 512.

        Returns:
            The decoded model answer, or an "Error ..." string on failure
            (this method never raises to the caller).
        """
        # Build the document context block shown to the model.
        context = "КОНТЕКСТ ИЗ ДОКУМЕНТОВ:\n"
        for doc in retrieved_docs:
            relevance = doc.get('relevance_score', 0)
            context += f"\n[Релевантность: {relevance:.2f}]\n{doc['document']}\n"

        # Keep the context inside the model's usable window.
        context = self.token_chunker.truncate_to_token_limit(context, token_limit=50000)

        system_prompt = "Ты помощник для анализа документов. Используй предоставленный контекст для ответа на вопросы. Отвечай на русском языке. Будь кратким и точным."

        # Single flat prompt: system instructions + context + question.
        full_query = f"{system_prompt}\n\n{context}\n\nВопрос: {query}\n\nОтвет:"

        query_tokens = self.token_chunker.count_tokens(full_query)
        print(f"Query token count: {query_tokens}")

        # Second-chance shrink: if still over budget, rebuild the context
        # from only the top-3 documents and a tighter truncation limit.
        if query_tokens > 100000:
            print(f"Query exceeds token limit. Reducing context...")

            context = "КОНТЕКСТ ИЗ ДОКУМЕНТОВ:\n"
            for doc in retrieved_docs[:3]:
                relevance = doc.get('relevance_score', 0)
                context += f"\n[Релевантность: {relevance:.2f}]\n{doc['document']}\n"

            context = self.token_chunker.truncate_to_token_limit(context, token_limit=30000)
            full_query = f"{system_prompt}\n\n{context}\n\nВопрос: {query}\n\nОтвет:"

        messages = self._prepare_text_message(full_query)

        # Optionally prepend the first retrieved image.
        # NOTE(review): this makes `messages` a flat list of content dicts
        # and duplicates the text entry; qwen_vl_utils normally expects
        # role/content chat messages — verify the image path actually works.
        if retrieved_images and len(retrieved_images) > 0:
            try:
                image_message = self._prepare_image_text_message(
                    retrieved_images[0],
                    f"Проанализируй это изображение в контексте вопроса: {query}"
                )
                messages = image_message + [{"type": "text", "text": full_query}]
            except Exception as e:
                print(f"Warning: Could not include images: {e}")

        image_inputs = []
        video_inputs = []

        # Extract vision inputs only when an image entry is present;
        # failures degrade to text-only generation.
        try:
            if any(msg.get('type') == 'image' for msg in messages):
                image_inputs, video_inputs = process_vision_info(messages)
        except Exception as e:
            print(f"Warning: Could not process images: {e}")

        # NOTE(review): the processor receives `full_query` as text, so the
        # per-image instruction built above never reaches the prompt text —
        # confirm this is intended.
        try:
            inputs = self.processor(
                text=[full_query],
                images=image_inputs if image_inputs else None,
                videos=video_inputs if video_inputs else None,
                padding=True,
                return_tensors='pt',
            )
        except Exception as e:
            print(f"Error preparing inputs: {e}")
            return f"Error preparing inputs: {e}"

        # Tensors must live on the same device as the model.
        if self.device == "cuda":
            inputs = inputs.to("cuda")

        # Deterministic greedy decoding, capped at 512 new tokens.
        try:
            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=min(max_new_tokens, 512),
                    num_beams=1,
                    do_sample=False
                )
        except Exception as e:
            print(f"Error during generation: {e}")
            return f"Error generating response: {e}"

        try:
            # Strip the echoed prompt: keep only tokens past the input length.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]

            response = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )

            return response[0] if response else "Could not generate response"
        except Exception as e:
            print(f"Error decoding response: {e}")
            return f"Error decoding response: {e}"
|
|
|
|
|
def summarize_document( |
|
|
self, |
|
|
document_text: str, |
|
|
max_new_tokens: int = 512 |
|
|
) -> str: |
|
|
"""Summarize a document with token limit management.""" |
|
|
|
|
|
|
|
|
document_text = self.token_chunker.truncate_to_token_limit( |
|
|
document_text, |
|
|
token_limit=40000 |
|
|
) |
|
|
|
|
|
prompt = f"""Пожалуйста, создай подробное резюме следующего документа на русском языке. |
|
|
|
|
|
Документ: |
|
|
{document_text} |
|
|
|
|
|
Резюме:""" |
|
|
|
|
|
messages = self._prepare_text_message(prompt) |
|
|
|
|
|
try: |
|
|
inputs = self.processor( |
|
|
text=[prompt], |
|
|
padding=True, |
|
|
return_tensors='pt', |
|
|
) |
|
|
|
|
|
if self.device == "cuda": |
|
|
inputs = inputs.to("cuda") |
|
|
|
|
|
with torch.no_grad(): |
|
|
generated_ids = self.model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=min(max_new_tokens, 512), |
|
|
num_beams=1, |
|
|
do_sample=False |
|
|
) |
|
|
|
|
|
generated_ids_trimmed = [ |
|
|
out_ids[len(in_ids):] |
|
|
for in_ids, out_ids in zip(inputs.input_ids, generated_ids) |
|
|
] |
|
|
|
|
|
response = self.processor.batch_decode( |
|
|
generated_ids_trimmed, |
|
|
skip_special_tokens=True, |
|
|
clean_up_tokenization_spaces=False |
|
|
) |
|
|
|
|
|
return response[0] if response else "Could not generate summary" |
|
|
except Exception as e: |
|
|
print(f"Error generating summary: {e}") |
|
|
return f"Error: {e}" |
|
|
|
|
|
|
|
|
class RAGPipeline:
    """Complete RAG pipeline combining retrieval and generation."""

    def __init__(self, chroma_manager, device: str = "cuda"):
        """Wire the retriever (Chroma) to the generator (Qwen2.5-VL)."""
        self.chroma_manager = chroma_manager
        self.inferencer = Qwen25VLInferencer(device=device)

    def answer_question(
        self,
        query: str,
        n_retrieved: int = 5,
        max_new_tokens: int = 512
    ) -> Dict:
        """
        Answer user question using RAG pipeline.
        1. Retrieve relevant documents
        2. Generate answer using Qwen2.5-VL
        """
        docs = self.chroma_manager.search(query, n_results=n_retrieved)

        # Guard clause: nothing retrieved -> fixed "not found" payload.
        if not docs:
            return {
                "answer": "Не найдены релевантные документы для ответа на вопрос.",
                "retrieved_docs": [],
                "query": query,
                "error": "No documents found"
            }

        # Image retrieval is not wired up yet; pass an empty list through.
        images: List[str] = []

        try:
            answer = self.inferencer.generate_answer(
                query=query,
                retrieved_docs=docs,
                retrieved_images=images,
                max_new_tokens=max_new_tokens
            )
        except Exception as exc:
            answer = f"Error generating answer: {exc}"

        return {
            "answer": answer,
            "retrieved_docs": docs,
            "query": query,
            "model": "Qwen2.5-VL-3B",
            "doc_count": len(docs)
        }

    def summarize_all_documents(self, max_chars: int = 100000) -> str:
        """Create summary of all indexed documents with token limits."""
        info = self.chroma_manager.get_collection_info()

        # Nothing indexed -> nothing to summarize.
        if info['document_count'] == 0:
            return "No documents in database to summarize."

        try:
            payload = self.chroma_manager.collection.get(include=['documents'])

            if not payload['documents']:
                return "Could not retrieve documents for summarization."

            # Concatenate up to 10 documents while staying under max_chars.
            combined = ""
            for doc in payload['documents'][:10]:
                if len(combined) + len(doc) >= max_chars:
                    break
                combined += doc + "\n\n"

            # Even the first document was too big: take a prefix of it.
            if not combined:
                combined = payload['documents'][0][:max_chars]

            return self.inferencer.summarize_document(combined)
        except Exception as exc:
            return f"Error summarizing documents: {exc}"