import gradio as gr
import torch
import sys
import os
import re
import json
import time
from datetime import datetime
from pathlib import Path

# Add the project root to the Python path
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from src.inference.inference import tokenizer, model  # Import from your inference.py
from src.vector_db.manager import ChromaVectorDBManager
from src.utils.performance import PerformanceMonitor
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Performance history file
PERFORMANCE_HISTORY_FILE = Path("performance_history.json")


def save_performance_metrics(metrics_data):
    """Save performance metrics to history file"""
    try:
        if PERFORMANCE_HISTORY_FILE.exists():
            with open(PERFORMANCE_HISTORY_FILE, 'r') as f:
                history = json.load(f)
        else:
            history = []
        history.append(metrics_data)
        with open(PERFORMANCE_HISTORY_FILE, 'w') as f:
            json.dump(history, f, indent=2)
    except Exception as e:
        logger.error(f"Failed to save performance metrics: {e}")


def calculate_performance_metrics(start_time, end_time, prompt_tokens, generated_tokens, peak_memory_mb):
    """Calculate performance metrics in the requested format"""
    inference_time = end_time - start_time
    total_tokens = prompt_tokens + generated_tokens

    # Throughput (tokens per second)
    throughput = total_tokens / inference_time if inference_time > 0 else 0

    # Inference latency (milliseconds per token)
    latency_ms = (inference_time * 1000) / total_tokens if total_tokens > 0 else 0

    return {
        "timestamp": datetime.now().isoformat(),
        "model": "Gemma-3-270M",
        "load_time_s": "N/A",  # Model is already loaded
        "inference_latency_ms": round(latency_ms, 2),
        "throughput_tokens_s": round(throughput, 2),
        "ram_usage_mb": round(peak_memory_mb, 2),
        "vram_usage_mb": 0,  # CPU-only model
        "energy_j": "N/A",  # Would require specialized monitoring
        "prompt_tokens": prompt_tokens,
        "generated_tokens": generated_tokens,
        "total_inference_time_s": round(inference_time, 3),
    }


# Initialize the vector DB manager
try:
    logger.info("Initializing ChromaDB manager")
    db_manager = ChromaVectorDBManager()

    # Check if the collection already has data
    stats = db_manager.get_collection_stats()
    logger.info(f"Database stats: {stats}")

    if stats.get('total_chunks', 0) == 0:
        logger.warning("No chunks found in database. Processing chunks...")
        success = db_manager.process_all_chunks()
        if not success:
            logger.error("Failed to process chunks")
        else:
            logger.info("Chunks processed successfully")
except Exception as e:
    logger.error(f"Failed to initialize vector database: {e}")
    raise


def chat_with_rag(user_query, show_context=False):
    """Chat function with RAG support using your existing model setup."""
    try:
        if not user_query.strip():
            return "Please enter a question.", ""

        logger.info(f"Processing query: {user_query}")

        # Get the top-k relevant chunks
        results = db_manager.search_for_rag(
            user_query,
            n_results=3,
            use_truncated=True,
            filter_128_context=True
        )

        if not results or results[0]['score'] < 0.5:
            return ("I can only answer questions based on the provided car manuals. "
                    "Please ask a question related to car maintenance or operation."), ""
        # Build the context with source information
        context_parts = []
        source_info = []
        for i, result in enumerate(results, 1):
            context_parts.append(result["text"])
            source_info.append(f"Source {i}: {result['source_file']} (Score: {result['score']:.3f})")

        context = "\n\n".join(context_parts)

        # Clean the context before feeding it to the model
        cleaned_context = re.sub(r'(\s*\.\s*){3,}', ' ', context)  # Remove long runs of dots
        cleaned_context = re.sub(r'\s+', ' ', cleaned_context).strip()  # Normalize whitespace

        sources = "\n".join(source_info)

        # Build a conversational, tightly constrained prompt
        prompt = f"""You are a specialized car manual assistant. Your sole purpose is to answer questions based ONLY on the provided text from car manuals.

Strictly follow these rules:
1. Base your answer *exclusively* on the "CONTEXT" provided below.
2. Synthesize a complete and coherent answer. Do not repeat fragments of the context.
3. If the answer is not found in the CONTEXT, you MUST state: "I'm sorry, but the answer to your question is not available in the provided car manuals."
4. Do not use any external knowledge or make up information.

CONTEXT:
--- START OF CONTEXT ---
{cleaned_context}
--- END OF CONTEXT ---

QUESTION: {user_query}

ANSWER:"""

        # Count prompt tokens
        prompt_tokens = len(tokenizer.encode(prompt))

        # Tokenize the prompt for CPU inference
        inputs = tokenizer(prompt, return_tensors="pt").to("cpu")

        # Monitor performance during inference
        with PerformanceMonitor("Model_Inference") as monitor:
            start_time = time.time()

            # Generate a response with conservative parameters for gemma-3-270m
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )

            end_time = time.time()

            # Decode only the newly generated tokens. Slicing the decoded string
            # by len(prompt) is unreliable because decoding can alter whitespace,
            # so slice the output token ids at the prompt length instead.
            generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
            answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

            # Count generated tokens directly from the sliced ids
            generated_tokens = len(generated_ids)

            # Get performance metrics from the monitor
            perf_metrics = monitor.stop_monitoring()

        # Calculate and save performance metrics
        metrics_data = calculate_performance_metrics(
            start_time, end_time, prompt_tokens, generated_tokens, perf_metrics.peak_memory
        )
        save_performance_metrics(metrics_data)

        # Log a performance summary
        logger.info("Performance Metrics:")
        logger.info(f"  Model: {metrics_data['model']}")
        logger.info(f"  Inference Latency: {metrics_data['inference_latency_ms']} ms")
        logger.info(f"  Throughput: {metrics_data['throughput_tokens_s']} tokens/s")
        logger.info(f"  RAM Usage: {metrics_data['ram_usage_mb']} MB")
        logger.info(f"  Tokens (prompt/generated): {metrics_data['prompt_tokens']}/{metrics_data['generated_tokens']}")

        if not answer:
            answer = "I apologize, but I couldn't generate a proper response. Please try rephrasing your question."
logger.info(f"Generated response length: {len(answer)} characters") # Add performance info to sources perf_info = f"\n\n**Performance Metrics:**\n" \ f"- Model: {metrics_data['model']}\n" \ f"- Inference Latency: {metrics_data['inference_latency_ms']} ms\n" \ f"- Throughput: {metrics_data['throughput_tokens_s']} tokens/s\n" \ f"- RAM Usage: {metrics_data['ram_usage_mb']} MB\n" \ f"- Total Inference Time: {metrics_data['total_inference_time_s']} s" # Return answer and sources if requested if show_context: return answer, f"**Sources Used:**\n{sources}\n\n**Context:**\n{context}{perf_info}" else: return answer, f"**Sources Used:**\n{sources}{perf_info}" except Exception as e: logger.error(f"Error in chat_with_rag: {e}") return f"Sorry, I encountered an error: {str(e)}", "" # Gradio Interface with gr.Blocks(title="Car Manual Assistant", theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🚗 Car Manual Assistant Ask questions about car maintenance and operations. Uses **Google Gemma-3-270M** model with RAG from car manuals. """) with gr.Row(): with gr.Column(scale=3): user_input = gr.Textbox( label="Ask a question about your car", placeholder="e.g., How do I change the engine oil? What is the recommended tire pressure?", lines=2 ) with gr.Row(): submit_btn = gr.Button("Submit", variant="primary") clear_btn = gr.Button("Clear", variant="secondary") with gr.Row(): with gr.Column(): answer_output = gr.Textbox( label="Answer", lines=6, interactive=False ) sources_output = gr.Markdown(label="Sources") # Example questions gr.Examples( examples=[ "How do I change the engine oil?", "What is the recommended tire pressure?", "How to check brake fluid level?", "When should I replace the air filter?", "How to jump start the car?", "What does the check engine light mean?", "How often should I service my car?", "How to change a flat tire?" ], inputs=user_input ) # Event handlers submit_btn.click( chat_with_rag, inputs=[user_input], outputs=[answer_output, sources_output] ) user_input.submit( chat_with_rag, inputs=[user_input], outputs=[answer_output, sources_output] ) clear_btn.click( lambda: ("", "", ""), outputs=[user_input, answer_output, sources_output] ) # Launch UI if __name__ == "__main__": logger.info("Launching Gradio interface...") logger.info("Model loaded from inference.py - google/gemma-3-270m on CPU") demo.launch()