"""
Test Mistral GGUF Model
Load and test a local GGUF model with llama-cpp-python
"""
import re
import sys
from pathlib import Path
from typing import Optional, Dict, Any
import time
# llama-cpp-python import
try:
    from llama_cpp import Llama
except ImportError:
    print("[ERROR] llama-cpp-python not installed!")
    print("Install with: pip install llama-cpp-python")
    sys.exit(1)
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
# Local imports
from utils import get_utils
# ============================================================================
# MODEL LOADER
# ============================================================================
class MistralLoader:
"""Load and manage Mistral GGUF model"""
def __init__(self, utils):
self.utils = utils
self.config = utils.config
self.logger = utils.logger
self.model = None
def load_model(self) -> Llama:
"""Load GGUF model with config parameters"""
# Get model path
# Use nested paths config for consistency
model_dir = self.config.get_path('paths', 'models', 'base')
model_file = self.config.get('model.base.file')
model_path = model_dir / model_file
if not model_path.exists():
raise FileNotFoundError(
f"Model not found: {model_path}\n"
f"Please place {model_file} in {model_dir}/"
)
self.logger.info(f"Loading model: {model_path}")
self.logger.info("This may take 1-2 minutes...")
# Get generation config
gen_config = self.config.get('model.generation', {})
context_window = gen_config.get('context_window', 4096)
n_batch = gen_config.get('n_batch', 512)
n_threads = gen_config.get('n_threads', 4)
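        # n_gpu_layers=-1 offloads all layers to the GPU; it is a no-op on
        # CPU-only builds of llama-cpp-python.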
n_gpu_layers = gen_config.get('n_gpu_layers', -1)
# Log configuration
        self.logger.info("Configuration:")
self.logger.info(f" Context window: {context_window}")
self.logger.info(f" Batch size: {n_batch}")
self.logger.info(f" Threads: {n_threads}")
self.logger.info(f" GPU layers: {n_gpu_layers} ({'all' if n_gpu_layers == -1 else n_gpu_layers})")
# Load model
start_time = time.time()
self.model = Llama(
model_path=str(model_path),
n_ctx=context_window,
n_batch=n_batch,
n_threads=n_threads,
n_gpu_layers=n_gpu_layers,
verbose=False
)
load_time = time.time() - start_time
self.logger.info(f"Model loaded in {load_time:.2f} seconds")
return self.model
def sanitize_output(self, text: str) -> str:
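        """Post-process raw model output into first-person, label-free prose.

        Strips bracketed citations, third-person attributions, section labels,
        markdown markers, and list prefixes, then appends a closing benediction
        if one is missing.
        """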
# Normalize whitespace
text = text.strip()
# Remove bracketed numeric citations like [12], [3], etc.
text = re.sub(r"\s*\[[0-9]+\]\s*", " ", text)
# Remove common third-person attribution phrases and passage references
patterns_to_remove = [
r"\b(?:Swami\s+)?Vivekananda\s+(?:said|says|writes|wrote|taught)\b[^\.\n]*[\.\n]?",
r"\b[Ii]n\s+[Pp]assage\s+\d+[^\.\n]*[\.\n]?",
r"\b[Pp]assage\s+\d+[:\-]\s*",
r"\b[Cc]ontext\s*[:\-]\s*",
]
for pat in patterns_to_remove:
text = re.sub(pat, " ", text)
# Remove label prefixes to enforce natural prose (no sections/labels)
label_prefixes = [
r"Direct Answer",
r"Evidence from Dataset",
r"Intellectual Fire",
r"Cultural Pride",
r"Moral Strength",
r"Universalism",
r"Call to Action",
r"Guidance",
]
label_pattern = r"(^|\n)\s*(?:" + "|".join(label_prefixes) + r")\s*:\s*"
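        # Match any label at the start of a line (e.g. "Direct Answer: ...")
        # and replace it with a bare newline so the prose reads without headings.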
text = re.sub(label_pattern, "\n", text)
# Strip markdown bold/italic markers
text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
text = re.sub(r"\*([^*]+)\*", r"\1", text)
# Remove list/numbered prefixes to avoid enumerated sections
text = re.sub(r"(?m)^\s*[-*]\s+", "", text)
text = re.sub(r"(?m)^\s*\d+\s*[\.)]\s+", "", text)
# Replace third-person name with first-person "I" when used standalone
text = re.sub(r"\b(?:Swami\s+)?Vivekananda\b", "I", text)
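        # Note: this is a plain substitution; surrounding grammar is not
        # adjusted, so possessives like "Vivekananda's" become "I's".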
# Collapse multiple spaces and newlines
text = re.sub(r"[ \t]{2,}", " ", text)
text = re.sub(r"\n{3,}", "\n\n", text)
# Ensure benediction at end
if not re.search(r"\bOm\s+Shanti\b", text, flags=re.IGNORECASE):
text = text.rstrip() + "\n\nOm Shanti."
return text.strip()
    def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate a response from the model.

        Args:
            prompt: Input prompt
            **kwargs: Overrides for generation parameters (max_tokens,
                temperature, top_p, top_k, repeat_penalty)

        Returns:
            Dict with the sanitized response text plus token counts and
            timing statistics.
        """
if not self.model:
raise RuntimeError("Model not loaded. Call load_model() first.")
# Get generation parameters from config
gen_config = self.config.get('model.generation', {})
max_tokens = kwargs.get('max_tokens', gen_config.get('max_tokens', 512))
        # Default to a lower temperature to reduce philosophical drift and increase factuality
        config_temp = gen_config.get('temperature')
        temperature = kwargs.get('temperature', 0.3 if config_temp is None else config_temp)
top_p = kwargs.get('top_p', gen_config.get('top_p', 0.85))
top_k = kwargs.get('top_k', gen_config.get('top_k', 40))
repeat_penalty = kwargs.get('repeat_penalty', gen_config.get('repeat_penalty', 1.1))
        self.logger.info("\nGenerating response...")
self.logger.info(f"Parameters: max_tokens={max_tokens}, temp={temperature}, top_p={top_p}")
# Generate
start_time = time.time()
output = self.model(
prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repeat_penalty=repeat_penalty,
stop=["</s>", "\n\n\n"], # Stop sequences
echo=False # Don't repeat the prompt
)
gen_time = time.time() - start_time
# Extract response
response_text = output['choices'][0]['text'].strip()
sanitized_text = self.sanitize_output(response_text)
# Calculate tokens per second
tokens_generated = output['usage']['completion_tokens']
tokens_per_sec = tokens_generated / gen_time if gen_time > 0 else 0
result = {
'response': sanitized_text,
'prompt_tokens': output['usage']['prompt_tokens'],
'completion_tokens': tokens_generated,
'total_tokens': output['usage']['total_tokens'],
'generation_time': gen_time,
'tokens_per_second': tokens_per_sec
}
self.logger.info(f"Generated {tokens_generated} tokens in {gen_time:.2f}s ({tokens_per_sec:.1f} tokens/s)")
return result
# ============================================================================
# PROMPT BUILDER (with RAG integration)
# ============================================================================
class VivekanandaPromptBuilder:
"""Build prompts for Vivekananda AI"""
def __init__(self, utils):
self.utils = utils
self.config = utils.config
self.prompts = self.config.get('prompts', {})
self.system_prompt = self.prompts.get('system', '')
self.rag_template = self.prompts.get('rag_template', {})
self.direct_template = self.prompts.get('direct_template', {})
self.guardrails = self.prompts.get('guardrails', {})
# Structured guidance to ensure answers meet user's criteria
self.structured_guidance = (
"Answer directly and correctly in natural prose.\n"
"- Start with a concise direct answer (1–2 lines).\n"
"- Ground everything in dataset/context facts; do not invent.\n"
"- Express intellectual fire, cultural pride, moral strength, and universalism naturally—without headings or labels.\n"
"- Avoid numbered lists, headings, or section labels; write coherent paragraphs.\n"
"- If the dataset lacks facts, state that plainly. At most one short quote if essential.\n"
"- Conclude with an uplifting benediction when appropriate (e.g., Om Shanti).\n"
)
# Minimal known facts to ensure correctness for direct biographical queries
self.known_facts = {
'parents': "Father: Vishwanath Datta; Mother: Bhuvaneshwari Devi.",
'birth_name': "Narendranath Datta.",
'born': "12 January 1863, Calcutta (now Kolkata).",
}
def _known_facts_context(self, question: str) -> str:
"""Return a small context block of known biographical facts when relevant."""
q = question.lower().strip()
parts = []
        if any(kw in q for kw in ["parents", "mother", "father"]):
parts.append(self.known_facts['parents'])
if any(kw in q for kw in ["birth name", "original name", "given name"]):
parts.append(f"Birth name: {self.known_facts['birth_name']}")
        if any(kw in q for kw in ["born", "date of birth"]):
parts.append(f"Born: {self.known_facts['born']}")
if not parts:
return ""
        return "Known biographical facts (for correctness):\n" + "\n".join(parts)
def build_mistral_format(self, question: str, context: Optional[str] = None) -> str:
"""Build prompt in Mistral instruction format"""
# Reinforce directness and dataset-grounding beyond config system prompt
system_prompt = (
self.system_prompt
+ "\n\n"
+ self.guardrails.get('direct_address', '')
+ "\n\nDo not default to abstract philosophical discourse; answer directly and factually."
+ " Use dataset/context truths only; if not present, clearly say 'I do not have enough context to answer precisely.'"
)
if context:
synthesis_hint = self.guardrails.get('synthesis_hint', '')
header_tpl = self.rag_template.get('header', '')
footer_txt = self.rag_template.get('footer', '')
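            # The header template may not define {context}/{question}
            # placeholders, so fall back to the raw template text if
            # formatting fails.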
try:
header = header_tpl.format(context="", question=question)
except Exception:
header = header_tpl
# Inject known facts for direct biographical questions to ensure correctness
known_ctx = self._known_facts_context(question)
context_block = (
"Use the following dataset context silently—do not quote or reference it verbatim;\n"
"ground answers strictly in its facts:\n"
"<CONTEXT>\n" + context + ("\n\n" + known_ctx if known_ctx else "") + "\n</CONTEXT>"
)
full_user_message = (
f"{header}\n\n{context_block}\n\n{self.structured_guidance}\n\n{footer_txt}\n\n{synthesis_hint}\n\n"
f"Question: {question}"
)
else:
known_ctx = self._known_facts_context(question)
full_user_message = (
self.direct_template.get('template', '')
+ ("\n\n" + known_ctx if known_ctx else "")
+ "\n\n" + self.structured_guidance
+ f"\n\nQuestion: {question}"
)
# Mistral Instruct format
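        # NOTE: depending on the llama-cpp-python version, the leading "<s>"
        # may be tokenized as literal text while the library also prepends its
        # own BOS token; if outputs look malformed, try omitting the "<s>".
        # (Version-dependent behavior; verify against your installed build.)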
prompt = f"<s>[INST] {system_prompt}\n\n{full_user_message} [/INST]"
return prompt
# ============================================================================
# TEST SUITE
# ============================================================================
class ModelTester:
"""Test model with various prompts"""
def __init__(self, utils, loader, prompt_builder):
self.utils = utils
self.logger = utils.logger
self.loader = loader
self.prompt_builder = prompt_builder
def test_single_query(self, question: str, context: Optional[str] = None, **kwargs):
"""Test single query"""
print("\n" + "="*80)
print("MODEL TEST")
print("="*80)
print(f"\nQuestion: {question}")
if context:
print(f"\nContext provided: {len(context)} characters")
# Build prompt
prompt = self.prompt_builder.build_mistral_format(question, context)
print("\n" + "-"*80)
print("PROMPT")
print("-"*80)
print(prompt[:500] + "..." if len(prompt) > 500 else prompt)
# Generate
result = self.loader.generate(prompt, **kwargs)
# Display results
print("\n" + "-"*80)
print("RESPONSE")
print("-"*80)
print(result['response'])
print("\n" + "-"*80)
print("STATISTICS")
print("-"*80)
print(f"Prompt tokens: {result['prompt_tokens']}")
print(f"Completion tokens: {result['completion_tokens']}")
print(f"Total tokens: {result['total_tokens']}")
print(f"Generation time: {result['generation_time']:.2f}s")
print(f"Speed: {result['tokens_per_second']:.1f} tokens/s")
print("="*80)
return result
def test_batch(self, questions: list, **kwargs):
"""Test multiple questions"""
self.logger.info(f"\nTesting {len(questions)} questions...")
results = []
total_time = 0
for idx, question in enumerate(questions, 1):
self.logger.info(f"\n[{idx}/{len(questions)}] Testing: {question}")
result = self.test_single_query(question, **kwargs)
results.append({
'question': question,
'result': result
})
total_time += result['generation_time']
# Summary
print("\n" + "="*80)
print("BATCH TEST SUMMARY")
print("="*80)
        print(f"Total questions: {len(questions)}")
        print(f"Total time: {total_time:.2f}s")
        if results:
            print(f"Average time per question: {total_time/len(questions):.2f}s")
            print(f"Average tokens/s: {sum(r['result']['tokens_per_second'] for r in results) / len(results):.1f}")
print("="*80)
return results
def test_with_rag(self, question: str, **kwargs):
"""Test with RAG retrieval"""
try:
# Import RAG retriever from file path (module name starts with digits)
import importlib.util
            rag_path = Path(__file__).parent / "02_query_rag.py"
            spec = importlib.util.spec_from_file_location("rag_module", rag_path)
            if spec is None or spec.loader is None:
                raise ImportError(f"Cannot load RAG module from {rag_path}")
            rag_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(rag_module)
RAGRetriever = rag_module.RAGRetriever
# Initialize RAG
retriever = RAGRetriever(self.utils)
retriever.load_vectorstore()
# Retrieve context
results = retriever.retrieve(question)
context = retriever.format_context(results)
# Test with context, forwarding generation params
return self.test_single_query(question, context=context, **kwargs)
except Exception as e:
self.logger.error(f"RAG test failed: {e}")
self.logger.info("Falling back to direct generation...")
return self.test_single_query(question, **kwargs)
# ============================================================================
# INTERACTIVE MODE
# ============================================================================
def interactive_mode(loader, prompt_builder, tester):
"""Interactive testing mode"""
print("\n" + "="*80)
print("🕉️ MISTRAL MODEL - INTERACTIVE MODE")
print("="*80)
print("Commands:")
print(" - Type your question")
print(" - 'rag <question>' - Test with RAG")
print(" - 'params' - Show generation parameters")
    print(" - 'quit', 'exit', or 'q' - Stop")
print("="*80)
while True:
try:
user_input = input("\n[YOU] ").strip()
if not user_input:
continue
if user_input.lower() in ['quit', 'exit', 'q']:
print("\n[INFO] Om Shanti 🕉️")
break
if user_input.lower() == 'params':
gen_config = loader.config.get('model.generation', {})
print("\nGeneration Parameters:")
for key, value in gen_config.items():
print(f" {key}: {value}")
continue
if user_input.lower().startswith('rag '):
question = user_input[4:].strip()
tester.test_with_rag(question)
else:
tester.test_single_query(user_input)
except KeyboardInterrupt:
print("\n\n[INFO] Interrupted")
break
except Exception as e:
print(f"\n[ERROR] {e}")
import traceback
traceback.print_exc()
# ============================================================================
# MAIN
# ============================================================================
def main():
"""Main execution"""
import argparse
parser = argparse.ArgumentParser(description="Test Mistral GGUF Model")
parser.add_argument('query', nargs='*', help='Question to ask')
parser.add_argument('--rag', action='store_true', help='Use RAG retrieval')
parser.add_argument('--batch', action='store_true', help='Run batch tests')
parser.add_argument('--max-tokens', type=int, help='Override max tokens')
parser.add_argument('--temperature', type=float, help='Override temperature')
args = parser.parse_args()
# Initialize
utils = get_utils()
logger = utils.logger
logger.info("="*80)
logger.info("MISTRAL MODEL TESTING")
logger.info("="*80)
try:
# Initialize components
loader = MistralLoader(utils)
prompt_builder = VivekanandaPromptBuilder(utils)
# Load model
loader.load_model()
# Initialize tester
tester = ModelTester(utils, loader, prompt_builder)
# Prepare kwargs
kwargs = {}
        if args.max_tokens is not None:
            kwargs['max_tokens'] = args.max_tokens
        # Compare against None so an explicit --temperature 0 is not ignored
        if args.temperature is not None:
            kwargs['temperature'] = args.temperature
if args.batch:
# Batch test mode
test_questions = utils.config.get('evaluation.test_queries', [
"What is Karma Yoga?",
"How can I overcome fear?",
"What is the purpose of meditation?"
])
tester.test_batch(test_questions, **kwargs)
elif args.query:
# Single query mode
question = ' '.join(args.query)
if args.rag:
tester.test_with_rag(question, **kwargs)
else:
tester.test_single_query(question, **kwargs)
else:
# Interactive mode
interactive_mode(loader, prompt_builder, tester)
return 0
except Exception as e:
logger.error(f"\nFATAL ERROR: {e}", exc_info=True)
return 1
if __name__ == "__main__":
sys.exit(main())