# CarAssistanceQA / app.py
import gradio as gr
import torch
import sys
import os
import re
import json
import time
from datetime import datetime
from pathlib import Path
# Add the project root to Python path
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))
from src.inference.inference import tokenizer, model  # tokenizer and model are loaded in src/inference/inference.py
from src.vector_db.manager import ChromaVectorDBManager
from src.utils.performance import PerformanceMonitor
import logging
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Performance history file
PERFORMANCE_HISTORY_FILE = Path("performance_history.json")
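# The history file holds a JSON list; each request appends one metrics dict
# (see calculate_performance_metrics below for the exact keys).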
def save_performance_metrics(metrics_data):
    """Save performance metrics to the history file."""
    try:
        if PERFORMANCE_HISTORY_FILE.exists():
            with open(PERFORMANCE_HISTORY_FILE, 'r') as f:
                history = json.load(f)
        else:
            history = []
        history.append(metrics_data)
        with open(PERFORMANCE_HISTORY_FILE, 'w') as f:
            json.dump(history, f, indent=2)
    except Exception as e:
        logger.error(f"Failed to save performance metrics: {e}")
def calculate_performance_metrics(start_time, end_time, prompt_tokens, generated_tokens, peak_memory_mb):
    """Calculate per-request performance metrics in the requested reporting format."""
    inference_time = end_time - start_time
    total_tokens = prompt_tokens + generated_tokens
    # Throughput: tokens processed per second
    throughput = total_tokens / inference_time if inference_time > 0 else 0
    # Inference latency: milliseconds per token
    latency_ms = (inference_time * 1000) / total_tokens if total_tokens > 0 else 0
    return {
        "timestamp": datetime.now().isoformat(),
        "model": "Gemma-3-270M",
        "load_time_s": "N/A",  # Model is already loaded
        "inference_latency_ms": round(latency_ms, 2),
        "throughput_tokens_s": round(throughput, 2),
        "ram_usage_mb": round(peak_memory_mb, 2),
        "vram_usage_mb": 0,  # CPU-only model
        "energy_j": "N/A",  # Would require specialized monitoring
        "prompt_tokens": prompt_tokens,
        "generated_tokens": generated_tokens,
        "total_inference_time_s": round(inference_time, 3)
    }
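# Worked example (illustrative numbers only): 120 prompt tokens + 80 generated
# tokens in 2.0 s of inference gives throughput = 200 / 2.0 = 100 tokens/s and
# latency = (2.0 * 1000) / 200 = 10.0 ms per token.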
# Initialize Vector DB Manager
try:
    logger.info("Initializing ChromaDB manager")
    db_manager = ChromaVectorDBManager()
    # Check if the collection has data
    stats = db_manager.get_collection_stats()
    logger.info(f"Database stats: {stats}")
    if stats.get('total_chunks', 0) == 0:
        logger.warning("No chunks found in database. Processing chunks...")
        success = db_manager.process_all_chunks()
        if not success:
            logger.error("Failed to process chunks")
        else:
            logger.info("Chunks processed successfully")
except Exception as e:
    logger.error(f"Failed to initialize vector database: {e}")
    raise
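# Assumption (based on usage below): each search_for_rag result is a dict with
# at least 'text', 'score', and 'source_file' keys; the exact schema lives in
# src/vector_db/manager.py.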
def chat_with_rag(user_query, show_context=False):
    """Chat function with RAG support using the existing model setup."""
    try:
        if not user_query.strip():
            return "Please enter a question.", ""
        logger.info(f"Processing query: {user_query}")
        # Retrieve the top-k most relevant chunks
        results = db_manager.search_for_rag(
            user_query,
            n_results=3,
            use_truncated=True,
            filter_128_context=True
        )
        if not results or results[0]['score'] < 0.5:
            return ("I can only answer questions based on the provided car manuals. "
                    "Please ask a question related to car maintenance or operation."), ""
        # Build context with source information
        context_parts = []
        source_info = []
        for i, result in enumerate(results, 1):
            context_parts.append(result["text"])
            source_info.append(f"Source {i}: {result['source_file']} (Score: {result['score']:.3f})")
        context = "\n\n".join(context_parts)
        # Clean the context before feeding it to the model
        cleaned_context = re.sub(r'(\s*\.\s*){3,}', ' ', context)  # Remove long runs of dots
        cleaned_context = re.sub(r'\s+', ' ', cleaned_context).strip()  # Normalize whitespace
        sources = "\n".join(source_info)
        # Build a conversational, strictly grounded prompt
        prompt = f"""You are a specialized car manual assistant. Your sole purpose is to answer questions based ONLY on the provided text from car manuals.
Strictly follow these rules:
1. Base your answer *exclusively* on the "CONTEXT" provided below.
2. Synthesize a complete and coherent answer. Do not repeat fragments of the context.
3. If the answer is not found in the CONTEXT, you MUST state: "I'm sorry, but the answer to your question is not available in the provided car manuals."
4. Do not use any external knowledge or make up information.
CONTEXT:
--- START OF CONTEXT ---
{cleaned_context}
--- END OF CONTEXT ---
QUESTION:
{user_query}
ANSWER:"""
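        # Note (assumption): tokenizer.encode() may prepend special tokens such as
        # BOS, so these token counts are approximate rather than exact.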
        # Count prompt tokens
        prompt_tokens = len(tokenizer.encode(prompt))
        # Tokenize the prompt for CPU inference
        inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
        # Run generation under performance monitoring
        with PerformanceMonitor("Model_Inference") as monitor:
            start_time = time.time()
            # Generate the response with conservative parameters for gemma-3-270m
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )
            end_time = time.time()
        # Decode and keep only the newly generated text (the decoded output echoes the prompt)
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = full_response[len(prompt):].strip()
        # Count generated tokens
        generated_tokens = len(tokenizer.encode(answer))
        # Get performance metrics from the monitor
        perf_metrics = monitor.stop_monitoring()
        # Calculate and save performance metrics
        metrics_data = calculate_performance_metrics(
            start_time,
            end_time,
            prompt_tokens,
            generated_tokens,
            perf_metrics.peak_memory
        )
        # Save to history
        save_performance_metrics(metrics_data)
        # Log a performance summary
        logger.info("Performance Metrics:")
        logger.info(f"  Model: {metrics_data['model']}")
        logger.info(f"  Inference Latency: {metrics_data['inference_latency_ms']} ms")
        logger.info(f"  Throughput: {metrics_data['throughput_tokens_s']} tokens/s")
        logger.info(f"  RAM Usage: {metrics_data['ram_usage_mb']} MB")
        logger.info(f"  Tokens (prompt/generated): {metrics_data['prompt_tokens']}/{metrics_data['generated_tokens']}")
        if not answer:
            answer = "I apologize, but I couldn't generate a proper response. Please try rephrasing your question."
        logger.info(f"Generated response length: {len(answer)} characters")
        # Append performance info to the sources panel
        perf_info = (
            "\n\n**Performance Metrics:**\n"
            f"- Model: {metrics_data['model']}\n"
            f"- Inference Latency: {metrics_data['inference_latency_ms']} ms\n"
            f"- Throughput: {metrics_data['throughput_tokens_s']} tokens/s\n"
            f"- RAM Usage: {metrics_data['ram_usage_mb']} MB\n"
            f"- Total Inference Time: {metrics_data['total_inference_time_s']} s"
        )
        # Return the answer, plus the retrieved context if requested
        if show_context:
            return answer, f"**Sources Used:**\n{sources}\n\n**Context:**\n{context}{perf_info}"
        return answer, f"**Sources Used:**\n{sources}{perf_info}"
    except Exception as e:
        logger.error(f"Error in chat_with_rag: {e}")
        return f"Sorry, I encountered an error: {str(e)}", ""
# Gradio Interface
with gr.Blocks(title="Car Manual Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🚗 Car Manual Assistant
Ask questions about car maintenance and operations. Uses the **Google Gemma-3-270M** model with RAG over car manuals.
""")
    with gr.Row():
        with gr.Column(scale=3):
            user_input = gr.Textbox(
                label="Ask a question about your car",
                placeholder="e.g., How do I change the engine oil? What is the recommended tire pressure?",
                lines=2
            )
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
    with gr.Row():
        with gr.Column():
            answer_output = gr.Textbox(
                label="Answer",
                lines=6,
                interactive=False
            )
            sources_output = gr.Markdown(label="Sources")
    # Example questions
    gr.Examples(
        examples=[
            "How do I change the engine oil?",
            "What is the recommended tire pressure?",
            "How to check brake fluid level?",
            "When should I replace the air filter?",
            "How to jump start the car?",
            "What does the check engine light mean?",
            "How often should I service my car?",
            "How to change a flat tire?"
        ],
        inputs=user_input
    )
    # Event handlers
    submit_btn.click(
        chat_with_rag,
        inputs=[user_input],
        outputs=[answer_output, sources_output]
    )
    user_input.submit(
        chat_with_rag,
        inputs=[user_input],
        outputs=[answer_output, sources_output]
    )
    # Clear the input, answer, and sources panels
    clear_btn.click(
        lambda: ("", "", ""),
        outputs=[user_input, answer_output, sources_output]
    )
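# Note: demo.launch() below uses Gradio defaults, which is what Hugging Face
# Spaces expects; for a local run you could pass e.g. share=True for a public link.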
# Launch UI
if __name__ == "__main__":
    logger.info("Launching Gradio interface...")
    logger.info("Model loaded from inference.py - google/gemma-3-270m on CPU")
    demo.launch()