import gradio as gr
import torch
import sys
import os
import re
import json
import time
from datetime import datetime
from pathlib import Path

# Add the project root to Python path
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from src.inference.inference import tokenizer, model  # Import from your inference.py
from src.vector_db.manager import ChromaVectorDBManager
from src.utils.performance import PerformanceMonitor
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Performance history file
PERFORMANCE_HISTORY_FILE = Path("performance_history.json")

def save_performance_metrics(metrics_data):
    """Save performance metrics to history file"""
    try:
        if PERFORMANCE_HISTORY_FILE.exists():
            with open(PERFORMANCE_HISTORY_FILE, 'r') as f:
                history = json.load(f)
        else:
            history = []
        
        history.append(metrics_data)
        
        with open(PERFORMANCE_HISTORY_FILE, 'w') as f:
            json.dump(history, f, indent=2)
            
    except Exception as e:
        logger.error(f"Failed to save performance metrics: {e}")

def calculate_performance_metrics(start_time, end_time, prompt_tokens, generated_tokens, peak_memory_mb):
    """Calculate performance metrics similar to the requested format"""
    inference_time = end_time - start_time
    total_tokens = prompt_tokens + generated_tokens
    
    # Calculate throughput (tokens per second)
    throughput = total_tokens / inference_time if inference_time > 0 else 0
    
    # Calculate inference latency (time per token in milliseconds)
    latency_ms = (inference_time * 1000) / total_tokens if total_tokens > 0 else 0
    
    return {
        "timestamp": datetime.now().isoformat(),
        "model": "Gemma-3-270M",
        "load_time_s": "N/A",  # Model is already loaded
        "inference_latency_ms": round(latency_ms, 2),
        "throughput_tokens_s": round(throughput, 2),
        "ram_usage_mb": round(peak_memory_mb, 2),
        "vram_usage_mb": 0,  # CPU-only model
        "energy_j": "N/A",  # Would require specialized monitoring
        "prompt_tokens": prompt_tokens,
        "generated_tokens": generated_tokens,
        "total_inference_time_s": round(inference_time, 3)
    }

# Initialize Vector DB Manager
try:
    logger.info("Initializing ChromaDB manager")
    db_manager = ChromaVectorDBManager()
    
    # Check if collection has data
    stats = db_manager.get_collection_stats()
    logger.info(f"Database stats: {stats}")
    
    if stats.get('total_chunks', 0) == 0:
        logger.warning("No chunks found in database. Processing chunks...")
        success = db_manager.process_all_chunks()
        if not success:
            logger.error("Failed to process chunks")
        else:
            logger.info("Chunks processed successfully")
            
except Exception as e:
    logger.error(f"Failed to initialize vector database: {e}")
    raise

def chat_with_rag(user_query, show_context=False):
    """Chat function with RAG support using your existing model setup."""
    try:
        if not user_query.strip():
            return "Please enter a question.", ""
        
        logger.info(f"Processing query: {user_query}")
        
        # Get top-k relevant chunks
        results = db_manager.search_for_rag(
            user_query, 
            n_results=3,
            use_truncated=True,
            filter_128_context=True
        )
        
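        # search_for_rag is assumed to return a similarity score (higher = more
        # relevant); anything below 0.5 is treated as out of scope for the manuals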
        if not results or results[0]['score'] < 0.5:
            return "I can only answer questions based on the provided car manuals. Please ask a question related to car maintenance or operation.", ""
        
        # Build context with source information
        context_parts = []
        source_info = []
        
        for i, result in enumerate(results, 1):
            context_parts.append(result["text"])
            source_info.append(f"Source {i}: {result['source_file']} (Score: {result['score']:.3f})")
        
        context = "\n\n".join(context_parts)
        
        # Clean the context before feeding it to the model
        cleaned_context = re.sub(r'(\s*\.\s*){3,}', ' ', context) # Remove long series of dots
        cleaned_context = re.sub(r'\s+', ' ', cleaned_context).strip() # Normalize whitespace
        
        sources = "\n".join(source_info)
        
        # Build a more conversational and helpful prompt
        prompt = f"""You are a specialized car manual assistant. Your sole purpose is to answer questions based ONLY on the provided text from car manuals.

Strictly follow these rules:
1.  Base your answer *exclusively* on the "CONTEXT" provided below.
2.  Synthesize a complete and coherent answer. Do not repeat fragments of the context.
3.  If the answer is not found in the CONTEXT, you MUST state: "I'm sorry, but the answer to your question is not available in the provided car manuals."
4.  Do not use any external knowledge or make up information.

CONTEXT:
--- START OF CONTEXT ---
{cleaned_context}
--- END OF CONTEXT ---

QUESTION:
{user_query}

ANSWER:"""
        
        # Count prompt tokens
        prompt_tokens = len(tokenizer.encode(prompt))
        
        # Use inference setup with performance monitoring
        inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
        
        # Start performance monitoring for inference
        with PerformanceMonitor("Model_Inference") as monitor:
            start_time = time.time()
            
            # Generate response with conservative parameters for gemma-3-270m
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )
            
            end_time = time.time()
        
        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) can drift if the tokenizer normalizes text on decode
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = outputs[0][prompt_length:]
        answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        
        # Count generated tokens
        generated_tokens = len(generated_ids)
        
        # Get performance metrics from monitor
        perf_metrics = monitor.stop_monitoring()
        
        # Calculate and save performance metrics
        metrics_data = calculate_performance_metrics(
            start_time, 
            end_time, 
            prompt_tokens, 
            generated_tokens, 
            perf_metrics.peak_memory
        )
        
        # Save to history
        save_performance_metrics(metrics_data)
        
        # Log performance summary
        logger.info(f"Performance Metrics:")
        logger.info(f"  Model: {metrics_data['model']}")
        logger.info(f"  Inference Latency: {metrics_data['inference_latency_ms']} ms")
        logger.info(f"  Throughput: {metrics_data['throughput_tokens_s']} tokens/s")
        logger.info(f"  RAM Usage: {metrics_data['ram_usage_mb']} MB")
        logger.info(f"  Tokens (prompt/generated): {metrics_data['prompt_tokens']}/{metrics_data['generated_tokens']}")
        
        if not answer:
            answer = "I apologize, but I couldn't generate a proper response. Please try rephrasing your question."
        
        logger.info(f"Generated response length: {len(answer)} characters")
        
        # Add performance info to sources
        perf_info = f"\n\n**Performance Metrics:**\n" \
                   f"- Model: {metrics_data['model']}\n" \
                   f"- Inference Latency: {metrics_data['inference_latency_ms']} ms\n" \
                   f"- Throughput: {metrics_data['throughput_tokens_s']} tokens/s\n" \
                   f"- RAM Usage: {metrics_data['ram_usage_mb']} MB\n" \
                   f"- Total Inference Time: {metrics_data['total_inference_time_s']} s"
        
        # Return answer and sources if requested
        if show_context:
            return answer, f"**Sources Used:**\n{sources}\n\n**Context:**\n{context}{perf_info}"
        else:
            return answer, f"**Sources Used:**\n{sources}{perf_info}"
            
    except Exception as e:
        logger.error(f"Error in chat_with_rag: {e}")
        return f"Sorry, I encountered an error: {str(e)}", ""

# Gradio Interface
with gr.Blocks(title="Car Manual Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🚗 Car Manual Assistant
    
    Ask questions about car maintenance and operations. Uses **Google Gemma-3-270M** model with RAG from car manuals.
    """)
    
    with gr.Row():
        with gr.Column(scale=3):
            user_input = gr.Textbox(
                label="Ask a question about your car",
                placeholder="e.g., How do I change the engine oil? What is the recommended tire pressure?",
                lines=2
            )
            
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
    
    with gr.Row():
        with gr.Column():
            answer_output = gr.Textbox(
                label="Answer",
                lines=6,
                interactive=False
            )
            sources_output = gr.Markdown(label="Sources")
    
    # Example questions
    gr.Examples(
        examples=[
            "How do I change the engine oil?",
            "What is the recommended tire pressure?",
            "How to check brake fluid level?",
            "When should I replace the air filter?",
            "How to jump start the car?",
            "What does the check engine light mean?",
            "How often should I service my car?",
            "How to change a flat tire?"
        ],
        inputs=user_input
    )
    
    # Event handlers
    submit_btn.click(
        chat_with_rag,
        inputs=[user_input],
        outputs=[answer_output, sources_output]
    )
    
    user_input.submit(
        chat_with_rag,
        inputs=[user_input],
        outputs=[answer_output, sources_output]
    )
    
    clear_btn.click(
        lambda: ("", "", ""),
        outputs=[user_input, answer_output, sources_output]
    )

# Launch UI
if __name__ == "__main__":
    logger.info("Launching Gradio interface...")
    logger.info("Model loaded from inference.py - google/gemma-3-270m on CPU")
    
    demo.launch()