from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import psutil
import os
import time
from typing import Dict
import numpy as np
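
# Runtime dependencies assumed for this script (versions not pinned here):
# transformers, torch, psutil, numpy.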

class MemoryTracker:
    @staticmethod
    def get_memory_usage() -> Dict[str, float]:
        """Get current memory usage statistics."""
        process = psutil.Process(os.getpid())
        memory_info = process.memory_info()
        return {
            'rss': memory_info.rss / (1024 * 1024),  # resident set size in MB
            'vms': memory_info.vms / (1024 * 1024),  # virtual memory size in MB
            'gpu': torch.cuda.memory_allocated() / (1024 * 1024) if torch.cuda.is_available() else 0  # GPU memory in MB
        }

    @staticmethod
    def format_memory_stats(stats: Dict[str, float]) -> str:
        """Format memory statistics into a readable string."""
        return (f"RSS Memory: {stats['rss']:.2f} MB\n"
                f"Virtual Memory: {stats['vms']:.2f} MB\n"
                f"GPU Memory: {stats['gpu']:.2f} MB")

class CustomerSupportBot:
    def __init__(self, model_path="models/customer_support_gpt"):
        """
        Initialize the customer support bot with the fine-tuned model and memory tracking.

        Args:
            model_path (str): Path to the saved model and tokenizer
        """
        # Record initial memory state
        self.initial_memory = MemoryTracker.get_memory_usage()

        # Load tokenizer and track memory
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.post_tokenizer_memory = MemoryTracker.get_memory_usage()

        # Safeguard: GPT-style tokenizers often ship without a pad token; fall back to EOS
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model and track memory
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.post_model_memory = MemoryTracker.get_memory_usage()

        # Device placement: forced to CPU here; swap in the commented expression to use a GPU
        self.device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)
        self.post_device_memory = MemoryTracker.get_memory_usage()

        # Calculate memory deltas for each loading stage
        self.memory_deltas = {
            'tokenizer_load': {k: self.post_tokenizer_memory[k] - self.initial_memory[k]
                               for k in self.initial_memory},
            'model_load': {k: self.post_model_memory[k] - self.post_tokenizer_memory[k]
                           for k in self.initial_memory},
            'device_transfer': {k: self.post_device_memory[k] - self.post_model_memory[k]
                                for k in self.initial_memory}
        }

        # Initialize inference memory tracking
        self.inference_memory_stats = []
    def get_memory_report(self) -> str:
        """Generate a comprehensive memory usage report."""
        report = ["Memory Usage Report:"]
        report.append("\nModel Loading Memory Changes:")
        report.append("Tokenizer Loading:")
        report.append(MemoryTracker.format_memory_stats(self.memory_deltas['tokenizer_load']))
        report.append("\nModel Loading:")
        report.append(MemoryTracker.format_memory_stats(self.memory_deltas['model_load']))
        report.append("\nDevice Transfer:")
        report.append(MemoryTracker.format_memory_stats(self.memory_deltas['device_transfer']))

        if self.inference_memory_stats:
            avg_inference_memory = {
                k: np.mean([stats[k] for stats in self.inference_memory_stats])
                for k in self.inference_memory_stats[0]
            }
            report.append("\nAverage Inference Memory Usage:")
            report.append(MemoryTracker.format_memory_stats(avg_inference_memory))

        return "\n".join(report)
    def generate_response(self, instruction, max_length=100, temperature=0.7):
        """
        Generate a response for a given customer support instruction/query with memory tracking.

        Args:
            instruction (str): Customer's query or instruction
            max_length (int): Maximum total sequence length (prompt + response) in tokens
            temperature (float): Controls randomness in generation

        Returns:
            tuple: (Generated response, Memory usage statistics)
        """
        # Record pre-inference memory
        pre_inference_memory = MemoryTracker.get_memory_usage()

        # Format and tokenize input
        input_text = f"Instruction: {instruction}\nResponse:"
        inputs = self.tokenizer(input_text, return_tensors="pt")
        inputs = inputs.to(self.device)

        # Generate response and track memory
        start_time = time.time()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                top_p=0.95,
                top_k=50
            )
        inference_time = time.time() - start_time

        # Record post-inference memory
        post_inference_memory = MemoryTracker.get_memory_usage()

        # Calculate memory delta for this inference
        inference_memory_delta = {
            k: post_inference_memory[k] - pre_inference_memory[k]
            for k in pre_inference_memory
        }
        self.inference_memory_stats.append(inference_memory_delta)

        # Decode response and strip the prompt prefix
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.split("Response:")[-1].strip()

        return response, {
            'memory_delta': inference_memory_delta,
            'inference_time': inference_time
        }

def main():
    # Initialize the bot
    print("Initializing bot and tracking memory usage...")
    bot = CustomerSupportBot()
    print(bot.get_memory_report())

    # Example queries
    example_queries = [
        "How do I reset my password?",
        "What are your shipping policies?",
        "I want to return a product.",
    ]

    # Generate and print responses with memory stats
    print("\nCustomer Support Bot Demo:\n")
    for query in example_queries:
        print(f"Customer: {query}")
        response, stats = bot.generate_response(query)
        print(f"Bot: {response}")
        print(f"Inference Memory Delta: {MemoryTracker.format_memory_stats(stats['memory_delta'])}")
        print(f"Inference Time: {stats['inference_time']:.2f} seconds\n")

    # Interactive mode
    print("Enter your questions (type 'quit' to exit):")
    while True:
        query = input("\nYour question: ")
        if query.lower() == 'quit':
            break
        response, stats = bot.generate_response(query)
        print(f"Bot: {response}")
        print(f"Inference Memory Delta: {MemoryTracker.format_memory_stats(stats['memory_delta'])}")
        print(f"Inference Time: {stats['inference_time']:.2f} seconds")

    # Print final memory report
    print("\nFinal Memory Report:")
    print(bot.get_memory_report())


if __name__ == "__main__":
    main()
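
# To try this locally (assuming a fine-tuned model and tokenizer were saved to
# models/customer_support_gpt with save_pretrained), run the script directly, e.g.:
#   python app.py   # "app.py" is an assumed filename for this Space's entry point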