import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import time
import gc
import os
import psutil

# Configuration
BASE_MODEL = "microsoft/phi-2"
ADAPTER_MODEL = "pradeep6kumar2024/phi2-qlora-assistant"
DEBUG = False  # Set to True to enable debug prints

# Memory monitoring
def get_memory_usage():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)  # MB
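
# get_memory_usage() reports the process resident set size in MB, so a quick
# sanity check looks like this (the figure shown is illustrative, not measured):
#   print(f"RSS: {get_memory_usage():.2f} MB")   # e.g. "RSS: 412.37 MB"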

class ModelWrapper:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.loaded = False

    def load_model(self):
        if not self.loaded:
            try:
                # Force CPU usage
                os.environ["CUDA_VISIBLE_DEVICES"] = ""
                device = torch.device("cpu")

                # Clear memory
                gc.collect()
                if DEBUG:
                    print(f"Memory before loading: {get_memory_usage():.2f} MB")

                print("Loading tokenizer...")
                self.tokenizer = AutoTokenizer.from_pretrained(
                    BASE_MODEL,
                    trust_remote_code=True,
                    padding_side="left"
                )
                self.tokenizer.pad_token = self.tokenizer.eos_token
                if DEBUG:
                    print(f"Memory after tokenizer: {get_memory_usage():.2f} MB")

                print("Loading base model...")
                base_model = AutoModelForCausalLM.from_pretrained(
                    BASE_MODEL,
                    torch_dtype=torch.float32,
                    device_map="cpu",
                    trust_remote_code=True,
                    use_flash_attention_2=False,
                    low_cpu_mem_usage=True,
                    offload_folder="offload"
                )
                if DEBUG:
                    print(f"Memory after base model: {get_memory_usage():.2f} MB")

                print("Loading LoRA adapter...")
                self.model = PeftModel.from_pretrained(
                    base_model,
                    ADAPTER_MODEL,
                    torch_dtype=torch.float32,
                    device_map="cpu"
                )
                # Drop the local reference to the base model. Note this does not
                # free the weights: the PeftModel above still holds them.
                del base_model
                gc.collect()
                if DEBUG:
                    print(f"Memory after adapter: {get_memory_usage():.2f} MB")

                self.model.eval()
                print("Model loading complete!")
                self.loaded = True
            except Exception as e:
                print(f"Error during model loading: {str(e)}")
                raise
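
    # Optional variant (a sketch, not used above): PEFT's merge_and_unload()
    # folds the LoRA deltas into the base weights, which can shave some latency
    # off CPU inference at the cost of losing the separable adapter. Whether
    # the speedup matters for this Space is untested:
    #   self.model = self.model.merge_and_unload()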

    def generate_response(self, prompt, max_length=256, temperature=0.7, top_p=0.9):
        if not self.loaded:
            self.load_model()
        try:
            # Use shorter prompts to save memory
            if "function" in prompt.lower() and "python" in prompt.lower():
                enhanced_prompt = f"Write Python function: {prompt}"
            elif any(word in prompt.lower() for word in ["explain", "what is", "how does", "describe"]):
                enhanced_prompt = f"Explain briefly: {prompt}"
            else:
                enhanced_prompt = prompt
            if DEBUG:
                print(f"Enhanced prompt: {enhanced_prompt}")

            # Tokenize input with a short max length to save memory
            inputs = self.tokenizer(
                enhanced_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=256,
                padding=True
            ).to("cpu")
            # Generate with minimal parameters; temperature and top_p are
            # clamped below the UI maximums to keep CPU output stable
            start_time = time.time()
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    # max_new_tokens (rather than max_length) so a long prompt
                    # cannot swallow the whole generation budget
                    max_new_tokens=min(max_length, 256),
                    min_length=10,
                    temperature=min(0.5, temperature),
                    top_p=min(0.85, top_p),
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.2,
                    no_repeat_ngram_size=3,
                    num_return_sequences=1,
                    num_beams=1  # Single sequence; beam search is too slow on CPU
                )

            # Decode response
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if DEBUG:
                print(f"Raw response: {response}")

            # Strip the echoed prompt from the start of the response
            if response.startswith(enhanced_prompt):
                response = response[len(enhanced_prompt):].strip()
                if DEBUG:
                    print(f"After prompt removal: {response}")

            # Basic cleanup only: drop stray conversation markers
            cleaned_response = response.replace("Human:", "").replace("Assistant:", "")
            if DEBUG and cleaned_response != response:
                print(f"After conversation removal: {cleaned_response}")
            response = cleaned_response

            # Ensure code examples are properly fenced
            if "```python" not in response and "def " in response:
                response = "```python\n" + response + "\n```"

            # Simple validation: fall back if the model produced almost nothing
            if len(response.strip()) < 10:
                if DEBUG:
                    print("Response validation failed - using fallback")
                if "function" in prompt.lower():
                    fallback_response = (
                        "```python\n"
                        "def add_numbers(a, b):\n"
                        "    return a + b\n"
                        "```"
                    )
                else:
                    fallback_response = "I apologize, but I couldn't generate a response. Please try with a simpler prompt."
                response = fallback_response

            # Clear memory after generation
            gc.collect()

            generation_time = time.time() - start_time
            return response, generation_time
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            raise
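
    # A minimal usage sketch outside the Gradio UI (assumes the Hub downloads
    # above succeed; the first call triggers the slow model load, and the
    # generated text will vary run to run):
    #   wrapper = ModelWrapper()
    #   text, seconds = wrapper.generate_response("Explain recursion", max_length=128)
    #   print(f"{seconds:.1f}s\n{text}")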

# Initialize model wrapper
model_wrapper = ModelWrapper()

def generate_text(prompt, max_length=256, temperature=0.5, top_p=0.85):
    """Gradio interface function"""
    try:
        if not prompt.strip():
            return "Please enter a prompt."
        response, gen_time = model_wrapper.generate_response(
            prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p
        )
        return f"Generated in {gen_time:.2f} seconds:\n\n{response}"
    except Exception as e:
        print(f"Error in generate_text: {str(e)}")
        return f"Error generating response: {str(e)}\nPlease try again with a shorter prompt."

# Create a very lightweight Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Enter your prompt",
            placeholder="Type your prompt here...",
            lines=3
        ),
        gr.Slider(
            minimum=64,
            maximum=256,
            value=192,
            step=32,
            label="Maximum Length",
            info="Keep this low for CPU"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=0.7,
            value=0.4,
            step=0.1,
            label="Temperature",
            info="Lower is better for CPU"
        ),
        gr.Slider(
            minimum=0.5,
            maximum=0.9,
            value=0.8,
            step=0.1,
            label="Top P",
            info="Controls diversity"
        ),
    ],
    outputs=gr.Textbox(label="Generated Response", lines=6),
    title="Phi-2 QLoRA Assistant (CPU-Optimized)",
| description="""This is a lightweight CPU version of the fine-tuned Phi-2 model. | |
| Tips: | |
| - Keep prompts short and specific | |
| - Use lower maximum length (128-192) for faster responses | |
| - Use lower temperature (0.3-0.5) for more reliable responses | |
| """, | |
    examples=[
        ["Write a Python function to calculate factorial", 192, 0.4, 0.8],
        ["Explain machine learning simply", 192, 0.4, 0.8],
        ["Write a short email to schedule a meeting", 192, 0.4, 0.8]
    ],
    cache_examples=False,
    concurrency_limit=1  # Gradio 4.x per-event limit, replacing queue(concurrency_count=...)
)

if __name__ == "__main__":
    # No explicit demo.queue() call is needed for this single-function app
    demo.launch(max_threads=1)  # Limit the number of worker threads
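
# If the app needs to be reachable from outside a local machine (e.g. when run
# in a notebook), launch() also accepts share=True; that flag is unnecessary on
# Spaces, which handles hosting itself:
#   demo.launch(max_threads=1, share=True)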