#!/usr/bin/env python3
"""
Gradio web application for chatting with 3 persona LoRA adapters.
Personas: Dog, Cat, Bird
"""
from __future__ import annotations

import gc
import json
import os
import sys
import types
from pathlib import Path

# Disable torch.compile and prevent bitsandbytes issues. Set these before
# importing torch/transformers so any import-time checks see them.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
os.environ["DISABLE_BITSANDBYTES_AUTO_INSTALL"] = "1"

import gradio as gr
import torch
# Patch the import system so that any attempt to import bitsandbytes
# gets a harmless stub module instead of the real package.
_original_import = (
    __builtins__["__import__"]
    if isinstance(__builtins__, dict)
    else __builtins__.__import__
)


def _patched_import(name, globals=None, locals=None, fromlist=(), level=0):
    if name and name.startswith("bitsandbytes"):
        if name not in sys.modules:
            dummy = types.ModuleType(name)
            dummy.__version__ = "0.0.0"
            dummy.nn = types.ModuleType("nn")
            dummy.optim = types.ModuleType("optim")
            dummy.cuda_setup = types.ModuleType("cuda_setup")

            class DummyLinear8bitLt:
                pass

            class DummyLinear4bit:
                pass

            dummy.nn.Linear8bitLt = DummyLinear8bitLt
            dummy.nn.Linear4bit = DummyLinear4bit
            sys.modules[name] = dummy
            sys.modules[f"{name}.nn"] = dummy.nn
            sys.modules[f"{name}.optim"] = dummy.optim
            sys.modules[f"{name}.cuda_setup"] = dummy.cuda_setup
        return sys.modules[name]
    return _original_import(name, globals, locals, fromlist, level)


if isinstance(__builtins__, dict):
    __builtins__["__import__"] = _patched_import
else:
    __builtins__.__import__ = _patched_import
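# Note: PEFT and transformers typically import bitsandbytes only when
# probing for 8-bit/4-bit quantization support. The stub lets those
# probes succeed without pulling in CUDA-only native code; isinstance
# checks against the dummy Linear8bitLt/Linear4bit classes simply fail,
# so the model is treated as unquantized.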
# Disable torch.compile at runtime as well
try:
    torch._dynamo.config.suppress_errors = True
    torch._dynamo.config.disable = True
except Exception:
    pass

if hasattr(torch, "compile"):
    _original_torch_compile = torch.compile

    def _noop_compile(func=None, *args, **kwargs):
        # Handle both call styles: torch.compile(fn) and @torch.compile(...)
        if func is not None:
            return func

        def decorator(f):
            return f

        return decorator

    torch.compile = _noop_compile
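# _original_torch_compile is kept only so the real compile could be
# restored later if needed; nothing else in this app uses it.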
# Import peft/transformers only after the import patch is installed,
# so any bitsandbytes lookup they perform hits the stub.
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Configuration
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
ADAPTER_REPO = "Tameem7/Persona-Animal"  # Main repository
ADAPTER_SUBDIRS = {
    "dog": "dog",
    "cat": "cat",
    "bird": "bird",
}
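# Assumed layout of ADAPTER_REPO: one subdirectory per persona, each
# containing the PEFT adapter files (adapter_config.json plus the adapter
# weights) and an optional persona_config.json with "persona_name" and
# "persona_description" keys.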
# Global variables
base_model = None
base_tokenizer = None
current_persona = None
current_model = None
current_tokenizer = None
current_config = None
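# Only one persona adapter is kept in memory at a time; switching
# personas frees the previous model before loading the next one, which
# keeps the memory footprint down.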
def load_base_model():
    """Load the base model and tokenizer (only once)."""
    global base_model, base_tokenizer
    if base_model is not None:
        return base_model, base_tokenizer

    print(f"Loading base model: {BASE_MODEL}")
    base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if base_tokenizer.pad_token is None:
        base_tokenizer.pad_token = base_tokenizer.eos_token

    # Determine device and dtype
    use_cuda = torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"
    dtype = torch.bfloat16 if use_cuda else torch.float32

    if use_cuda:
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            dtype=dtype,
            device_map="auto",
        )
    else:
        print("💻 Running on CPU")
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            dtype=dtype,
        )
        base_model = base_model.to(device)
    base_model.eval()
    print("✅ Base model loaded")
    return base_model, base_tokenizer
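# Note: load_base_model() mainly warms the Hugging Face cache and provides
# the shared tokenizer; load_persona_adapter() below builds a fresh
# base-model copy for each adapter rather than reusing this instance.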
def load_persona_adapter(persona_key: str):
    """Load a persona adapter."""
    global current_persona, current_model, current_tokenizer, current_config, base_model, base_tokenizer

    # If the same persona is already loaded, return it
    if current_persona == persona_key and current_model is not None:
        return current_model, current_tokenizer, current_config

    # Load the base model (and tokenizer) if not loaded yet
    if base_model is None:
        load_base_model()

    # Unload the previous adapter
    if current_model is not None and current_persona != persona_key:
        print(f"Unloading previous adapter: {current_persona}")
        del current_model
        current_model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Get the adapter subdirectory
    adapter_subdir = ADAPTER_SUBDIRS.get(persona_key)
    if not adapter_subdir:
        raise ValueError(f"Unknown persona: {persona_key}")
| print(f"Loading adapter: {ADAPTER_REPO}/{adapter_subdir}") | |
| # Create a fresh copy of base model for the adapter | |
| # (PEFT needs a clean base model) | |
| if current_persona != persona_key: | |
| # Reload base model for new adapter | |
| print(f"Creating base model copy for {persona_key} adapter...") | |
| # Determine device and dtype | |
| use_cuda = torch.cuda.is_available() | |
| device = "cuda:0" if use_cuda else "cpu" | |
| dtype = torch.bfloat16 if use_cuda else torch.float32 | |
| if use_cuda: | |
| base_model_copy = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| dtype=dtype, | |
| device_map="auto", | |
| ) | |
| else: | |
| base_model_copy = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| dtype=dtype, | |
| ) | |
| base_model_copy = base_model_copy.to(device) | |
    # Download adapter files from the subdirectory and load them
    try:
        from huggingface_hub import snapshot_download

        # Download the entire repo (or reuse the local cache if available)
        print(f"Downloading adapter files from {ADAPTER_REPO}/{adapter_subdir}...")
        repo_cache_dir = snapshot_download(
            repo_id=ADAPTER_REPO,
            repo_type="model",
            cache_dir=None,  # Use the default cache
        )

        # The adapter files should be in repo_cache_dir/adapter_subdir/
        adapter_local_path = os.path.join(repo_cache_dir, adapter_subdir)
        if not os.path.exists(adapter_local_path):
            raise FileNotFoundError(f"Adapter subdirectory not found: {adapter_local_path}")

        # Check that adapter_config.json exists
        adapter_config_path = os.path.join(adapter_local_path, "adapter_config.json")
        if not os.path.exists(adapter_config_path):
            raise FileNotFoundError(f"adapter_config.json not found in {adapter_local_path}")

        # Load the adapter from the local path
        print(f"Loading adapter from: {adapter_local_path}")
        current_model = PeftModel.from_pretrained(base_model_copy, adapter_local_path)
        current_model.eval()
    except Exception as e:
        error_msg = f"Failed to load adapter from {ADAPTER_REPO}/{adapter_subdir}: {e}"
        print(f"❌ {error_msg}")
        print("Make sure the adapter files are uploaded to the correct subdirectory on Hugging Face.")
        raise RuntimeError(error_msg) from e
    # Load the persona config
    try:
        from huggingface_hub import hf_hub_download

        config_path = hf_hub_download(
            repo_id=ADAPTER_REPO,
            filename=f"{adapter_subdir}/persona_config.json",
            repo_type="model",
        )
        with open(config_path, "r") as f:
            current_config = json.load(f)
    except Exception as e:
        print(f"⚠️ Could not load persona_config.json: {e}")
        print("Using a default persona config instead.")
        current_config = {"persona_name": persona_key.title(), "persona_description": ""}

    current_persona = persona_key
    current_tokenizer = base_tokenizer
    print(f"✅ Loaded {persona_key} persona")
    return current_model, current_tokenizer, current_config
def generate_response(persona_key: str, message: str, history: list, max_tokens: int = 80):
    """Generate a response from the selected persona."""
    global current_model, current_tokenizer, current_config
    if not message or not message.strip():
        return history, ""

    try:
        # Load the adapter if needed
        model, tokenizer, config = load_persona_adapter(persona_key)

        # Build the message list, starting with the persona system prompt
        system_prompt = ""
        if config:
            system_prompt = f"You are {config.get('persona_name', '')}. {config.get('persona_description', '')}"
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Add recent conversation history (only the last 5 entries, to keep
        # the context short). Handle both the messages format (dicts with a
        # "role" key) and the legacy tuple format [user_msg, assistant_msg].
        for item in history[-5:]:
            if isinstance(item, dict) and "role" in item:
                # Messages format
                messages.append(item)
            else:
                # Tuple format [user_msg, assistant_msg]
                user_msg, assistant_msg = item
                messages.append({"role": "user", "content": user_msg})
                messages.append({"role": "assistant", "content": assistant_msg})

        # Add the current message
        messages.append({"role": "user", "content": message})
        # Apply the chat template
        formatted = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
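        # TinyLlama-1.1B-Chat ships a Zephyr-style chat template, so the
        # formatted string uses <|system|>/<|user|>/<|assistant|> role
        # markers; add_generation_prompt=True appends the assistant marker
        # so the model continues in the persona's voice.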
        # Tokenize, capping the prompt at 512 tokens
        inputs = tokenizer(formatted, return_tensors="pt", truncation=True, max_length=512)

        # Move the inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_tokens),  # slider values may arrive as floats
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
            )
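        # Sampling rationale: temperature 0.7 with nucleus (top-p 0.9)
        # sampling keeps replies varied, while repetition_penalty and
        # no_repeat_ngram_size curb the looping that small models like
        # TinyLlama are prone to.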
        # Keep only the newly generated tokens and decode them
        input_length = inputs["input_ids"].shape[1]
        generated_tokens = outputs[0][input_length:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Clean up the response
        response = response.strip()
        if tokenizer.eos_token:
            response = response.replace(tokenizer.eos_token, "").strip()
        if tokenizer.pad_token:
            response = response.replace(tokenizer.pad_token, "").strip()

        # Remove chat template artifacts
        response = response.replace("<|system|>", "").replace("</|system|>", "")
        response = response.replace("<|user|>", "").replace("</|user|>", "")
        response = response.replace("<|assistant|>", "").replace("</|assistant|>", "")
        response = response.replace("<|", "").replace("|>", "")

        # Collapse extra whitespace
        response = " ".join(response.split()).strip()
        # Append the turn in messages format (what the Chatbot below expects)
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})
        return history, ""
    except Exception as e:
        error_msg = f"Error generating response: {e}"
        print(error_msg)
        return history, error_msg
def clear_chat():
    """Clear the chat history."""
    return [], ""  # Empty history (messages format) and an empty textbox
# Create the Gradio interface
with gr.Blocks(title="Persona Chat") as app:
    gr.Markdown(
        """
        # 🐾 Persona Chat - Talk to Animals!

        Chat with three different animal personas, each with their own unique personality:
        - **🐕 Dog**: Friendly, playful, and enthusiastic
        - **🐱 Cat**: Independent, curious, and sometimes sassy
        - **🐦 Bird**: Energetic, talkative, and free-spirited

        **💻 Running on CPU** - Responses may be slower, but everything still works!
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            persona_dropdown = gr.Dropdown(
                choices=["dog", "cat", "bird"],
                value="dog",
                label="Select Persona",
            )
            max_tokens_slider = gr.Slider(
                minimum=20,
                maximum=150,
                value=80,
                step=10,
                label="Max Response Length",
            )
            clear_btn = gr.Button("Clear Chat")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Chat",
                height=400,
                type="messages",  # matches the dict-based history built above (needs a recent Gradio)
            )
            msg_input = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=2,
            )
            send_btn = gr.Button("Send")
    # Event handlers
    def chat_fn(persona, message, history, max_tokens):
        return generate_response(persona, message, history, max_tokens)

    send_btn.click(
        fn=chat_fn,
        inputs=[persona_dropdown, msg_input, chatbot, max_tokens_slider],
        outputs=[chatbot, msg_input],
    )
    msg_input.submit(
        fn=chat_fn,
        inputs=[persona_dropdown, msg_input, chatbot, max_tokens_slider],
        outputs=[chatbot, msg_input],
    )
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, msg_input],
    )
if __name__ == "__main__":
    # Load the base model up front so the first chat turn is faster
    print("Initializing...")
    load_base_model()
    app.launch(
        server_name="0.0.0.0" if os.getenv("SPACE_ID") else "127.0.0.1",
        server_port=int(os.getenv("PORT", 7860)),
        share=False,
    )
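# On a Hugging Face Space the SPACE_ID env var is set automatically, so
# the app binds to 0.0.0.0 there and to 127.0.0.1 when run locally; PORT
# (default 7860) can override the listening port in either case.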