import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from peft import PeftModel

# Check if CUDA is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model_and_tokenizer(model_path="./final_model"):
    """Load the fine-tuned phi-2 model and tokenizer."""
    print(f"Loading fine-tuned model from {model_path}...")
    print(f"Using device: {device}")

    # Load the base tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/phi-2",
        trust_remote_code=True
    )

    # Use a pad token distinct from eos_token so padding is visible in the attention mask
    if tokenizer.pad_token is None or tokenizer.pad_token == tokenizer.eos_token:
        # Guard against tokenizers that define no unk_token
        tokenizer.pad_token = tokenizer.unk_token if tokenizer.unk_token is not None else tokenizer.eos_token
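    # Rationale (a general transformers behavior, not specific to this app): when
    # pad_token_id equals eos_token_id, padded positions are indistinguishable from
    # the end-of-sequence token, which triggers attention-mask warnings and can
    # cause generation to behave inconsistently with batched or padded inputs.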
    # 4-bit quantization (bitsandbytes) is only usable on CUDA devices
    if device.type == "cuda":
        from transformers import BitsAndBytesConfig

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True
        )
        # Load the base model with 4-bit quantization
        base_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2",
            quantization_config=bnb_config,
            trust_remote_code=True,
            device_map="auto"
        )
    else:
        # On CPU, try 8-bit first and fall back to full precision
        print("Loading on CPU - using 8-bit quantization or full precision")
        try:
            # 8-bit loading requires bitsandbytes, which generally needs a CUDA
            # GPU, so this usually raises on CPU-only machines
            base_model = AutoModelForCausalLM.from_pretrained(
                "microsoft/phi-2",
                load_in_8bit=True,
                trust_remote_code=True,
                device_map="auto"
            )
        except Exception:
            # Fall back to full precision if 8-bit fails
            print("8-bit loading failed, falling back to full precision (fp32)")
            base_model = AutoModelForCausalLM.from_pretrained(
                "microsoft/phi-2",
                trust_remote_code=True,
                torch_dtype=torch.float32
            )
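    # Note: recent transformers releases prefer passing
    # quantization_config=BitsAndBytesConfig(load_in_8bit=True) rather than the
    # bare load_in_8bit flag; the flag is kept above to match the original code.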
    # Load the fine-tuned LoRA adapter on top of the base model
    try:
        model = PeftModel.from_pretrained(
            base_model,
            model_path,
            device_map="auto" if device.type == "cuda" else None
        )
    except Exception as e:
        print(f"Error loading LoRA adapter: {e}")
        print("Falling back to base model")
        model = base_model

    # Move the model to CPU if needed (on GPU, device_map already placed it)
    if device.type == "cpu":
        model = model.to(device)

    model.eval()  # Set model to evaluation mode
    print("Fine-tuned model loaded successfully!")
    return model, tokenizer
def format_chat_history(messages):
    """Format the chat history into a prompt for the model."""
    formatted_prompt = ""
    for message in messages:
        role = message["role"]
        content = message["content"]
        if role == "user":
            formatted_prompt += f"Human: {content}\n\n"
        elif role == "assistant":
            formatted_prompt += f"Assistant: {content}\n\n"
    # Add the final assistant prompt so the model continues as the assistant
    formatted_prompt += "Assistant:"
    return formatted_prompt
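# For illustration, format_chat_history([
#     {"role": "user", "content": "What is LoRA?"},
#     {"role": "assistant", "content": "A low-rank adaptation method."},
#     {"role": "user", "content": "Why use it?"},
# ]) produces:
#
#     Human: What is LoRA?
#
#     Assistant: A low-rank adaptation method.
#
#     Human: Why use it?
#
#     Assistant: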
def generate_response(model, tokenizer, messages):
    """Generate a streaming response from the model based on chat history."""
    # Format the conversation history into a single prompt
    prompt = format_chat_history(messages)

    # Tokenize the input and move it to the model's device
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Create a streamer that yields decoded text as tokens are produced
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Set generation parameters
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 500,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
    }

    # Run generation in a background thread; model.generate blocks until it
    # finishes, while the streamer lets us consume tokens here as they arrive
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Stream tokens as they are generated, stopping if the model starts a new turn
    generated_text = ""
    for new_text in streamer:
        # Stop if the model tries to open a new human turn (also check the
        # accumulated text, in case the marker is split across streamed chunks)
        if "Human:" in new_text or "\nHuman:" in generated_text + new_text:
            break
        # Stop if the model tries to start another assistant turn
        if "Assistant:" in new_text and generated_text:
            break
        yield new_text
        generated_text += new_text
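
# Minimal usage sketch (an assumption for illustration: the original Space wires
# these functions into its own UI; the console loop and the default adapter path
# "./final_model" are not part of that wiring).
if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer()
    history = []
    while True:
        user_input = input("You: ").strip()
        if not user_input:
            break
        history.append({"role": "user", "content": user_input})
        print("Assistant: ", end="", flush=True)
        reply = ""
        # Consume the generator to print the response as it streams
        for chunk in generate_response(model, tokenizer, history):
            print(chunk, end="", flush=True)
            reply += chunk
        print()
        history.append({"role": "assistant", "content": reply.strip()})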