Spaces: Build error
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from typing import List, Dict, Any
import gc
import psutil


class ResponseGenerator:
    def __init__(self, model_name: str = "microsoft/phi-2"):
        """
        Initialize the response generator with an LLM
        """
        print(f"Loading LLM: {model_name}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
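        # Free any leftover memory from earlier loads before fetching the model weights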
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()
        try:
            if self.device == "cuda":
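                # Quantize weights to 4-bit NF4 (with double quantization) to cut GPU
                # memory use; compute still runs in float16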
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4"
                )
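                # Cap how much GPU VRAM and system RAM accelerate may use when placing
                # layers; whatever does not fit is offloaded to the "offload" folder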
                available_memory = psutil.virtual_memory().total / (1024 ** 3)
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
                # Round down to whole GiB so the size strings stay in a form that
                # accelerate parses reliably (e.g. "15GiB" rather than "14.7GiB")
                max_memory = {0: f"{int(min(gpu_memory, 15))}GiB", "cpu": f"{int(min(available_memory, 30))}GiB"}
                print(f"Setting max_memory: {max_memory}")
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.float16,
                    max_memory=max_memory,
                    offload_folder="offload",
                    offload_state_dict=True,
                    low_cpu_mem_usage=True
                )
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    device_map={"": "cpu"},
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )
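        # If the primary model fails to load (e.g. it runs out of memory), fall back
        # to a much smaller chat model so the Space can still start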
        except Exception as e:
            print(f"Model loading error: {e}")
            print("Falling back to TinyLlama...")
            model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map={"": self.device},
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            )
        print("LLM loaded successfully")

    def generate_response(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
        """
        Generate a response using the LLM with retrieved FAQs as context
        """
        prompt = self._create_prompt(query, relevant_faqs)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
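        # Generate up to 150 new tokens; sampling with temperature/top-p keeps answers
        # varied but on-topic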
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt) is fragile because decoding does not reproduce the prompt exactly
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        if self.device == "cuda":
            torch.cuda.empty_cache()
        return response

    def _create_prompt(self, query: str, relevant_faqs: List[Dict[str, Any]]) -> str:
        """
        Create a prompt for the LLM with retrieved FAQs as context
        """
        faq_context = "\n\n".join(
            [f"Q: {faq['question']}\nA: {faq['answer']}" for faq in relevant_faqs]
        )
        prompt = f"""
Below are some relevant e-commerce customer support FAQ entries:
{faq_context}
Based on the information above, provide a helpful, accurate, and concise response to the following customer query:
Customer Query: {query}
Response:
"""
        return prompt
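
For reference, a minimal sketch of how the class might be exercised locally; the FAQ entries, the query, and the __main__ guard below are made-up placeholders for illustration, not part of the Space itself.

# Minimal usage sketch (hypothetical data; keys mirror what _create_prompt reads)
if __name__ == "__main__":
    generator = ResponseGenerator()
    faqs = [
        {"question": "How long does shipping take?",
         "answer": "Standard shipping takes 3-5 business days."},
        {"question": "Can I track my order?",
         "answer": "Yes, a tracking link is emailed once the order ships."},
    ]
    print(generator.generate_response("Where is my package?", faqs))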