#!/usr/bin/env python3 """ Gradio web application for chatting with 3 persona LoRA adapters. Personas: Dog, Cat, Bird """ from __future__ import annotations import os import sys import types import json import gc import gradio as gr import torch from pathlib import Path # Disable torch.compile and prevent bitsandbytes issues os.environ["TORCH_COMPILE_DISABLE"] = "1" os.environ["BITSANDBYTES_NOWELCOME"] = "1" os.environ["DISABLE_BITSANDBYTES_AUTO_INSTALL"] = "1" # Patch import system to prevent bitsandbytes import _original_import = __builtins__.__import__ def _patched_import(name, globals=None, locals=None, fromlist=(), level=0): if name == "bitsandbytes" or (name and name.startswith("bitsandbytes")): if name not in sys.modules: dummy = types.ModuleType(name) dummy.__version__ = "0.0.0" dummy.nn = types.ModuleType("nn") dummy.optim = types.ModuleType("optim") dummy.cuda_setup = types.ModuleType("cuda_setup") class DummyLinear8bitLt: pass class DummyLinear4bit: pass dummy.nn.Linear8bitLt = DummyLinear8bitLt dummy.nn.Linear4bit = DummyLinear4bit sys.modules[name] = dummy sys.modules[f"{name}.nn"] = dummy.nn sys.modules[f"{name}.optim"] = dummy.optim sys.modules[f"{name}.cuda_setup"] = dummy.cuda_setup return sys.modules[name] return _original_import(name, globals, locals, fromlist, level) if isinstance(__builtins__, dict): __builtins__["__import__"] = _patched_import else: __builtins__.__import__ = _patched_import # Disable torch.compile try: torch._dynamo.config.suppress_errors = True torch._dynamo.config.disable = True except: pass if hasattr(torch, "compile"): _original_torch_compile = torch.compile def _noop_compile(func=None, *args, **kwargs): if func is not None: return func def decorator(f): return f return decorator torch.compile = _noop_compile from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer # Configuration BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" ADAPTER_REPO = "Tameem7/Persona-Animal" # Main repository ADAPTER_SUBDIRS = { "dog": "dog", "cat": "cat", "bird": "bird", } # Global variables base_model = None base_tokenizer = None current_persona = None current_model = None current_tokenizer = None current_config = None def load_base_model(): """Load the base model and tokenizer (only once).""" global base_model, base_tokenizer if base_model is not None: return base_model, base_tokenizer print(f"Loading base model: {BASE_MODEL}") base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) if base_tokenizer.pad_token is None: base_tokenizer.pad_token = base_tokenizer.eos_token # Determine device and dtype use_cuda = torch.cuda.is_available() device = "cuda:0" if use_cuda else "cpu" dtype = torch.bfloat16 if use_cuda else torch.float32 if use_cuda: base_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, dtype=dtype, device_map="auto", ) else: print("💻 Running on CPU") base_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, dtype=dtype, ) base_model = base_model.to(device) base_model.eval() print("✅ Base model loaded") return base_model, base_tokenizer def load_persona_adapter(persona_key: str): """Load a persona adapter.""" global current_persona, current_model, current_tokenizer, current_config, base_model, base_tokenizer # If same persona is already loaded, return if current_persona == persona_key and current_model is not None: return current_model, current_tokenizer, current_config # Load base model if not loaded if base_model is None: load_base_model() # Unload previous adapter if current_model is not None and current_persona != persona_key: print(f"Unloading previous adapter: {current_persona}") del current_model current_model = None gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # Get adapter subdirectory adapter_subdir = ADAPTER_SUBDIRS.get(persona_key) if not adapter_subdir: raise ValueError(f"Unknown persona: {persona_key}") print(f"Loading adapter: {ADAPTER_REPO}/{adapter_subdir}") # Create a fresh copy of base model for the adapter # (PEFT needs a clean base model) if current_persona != persona_key: # Reload base model for new adapter print(f"Creating base model copy for {persona_key} adapter...") # Determine device and dtype use_cuda = torch.cuda.is_available() device = "cuda:0" if use_cuda else "cpu" dtype = torch.bfloat16 if use_cuda else torch.float32 if use_cuda: base_model_copy = AutoModelForCausalLM.from_pretrained( BASE_MODEL, dtype=dtype, device_map="auto", ) else: base_model_copy = AutoModelForCausalLM.from_pretrained( BASE_MODEL, dtype=dtype, ) base_model_copy = base_model_copy.to(device) # Download adapter files from subdirectory and load try: from huggingface_hub import snapshot_download # Download the adapter files from the subdirectory print(f"Downloading adapter files from {ADAPTER_REPO}/{adapter_subdir}...") # Download the entire repo (or use cache if available) repo_cache_dir = snapshot_download( repo_id=ADAPTER_REPO, repo_type="model", cache_dir=None, # Use default cache ) # The adapter files should be in repo_cache_dir/adapter_subdir/ adapter_local_path = os.path.join(repo_cache_dir, adapter_subdir) if not os.path.exists(adapter_local_path): raise FileNotFoundError(f"Adapter subdirectory not found: {adapter_local_path}") # Check if adapter_config.json exists adapter_config_path = os.path.join(adapter_local_path, "adapter_config.json") if not os.path.exists(adapter_config_path): raise FileNotFoundError(f"adapter_config.json not found in {adapter_local_path}") # Load adapter from local path print(f"Loading adapter from: {adapter_local_path}") current_model = PeftModel.from_pretrained(base_model_copy, adapter_local_path) current_model.eval() except Exception as e: error_msg = f"Failed to load adapter from {ADAPTER_REPO}/{adapter_subdir}: {str(e)}" print(f"❌ {error_msg}") print("Make sure the adapter files are uploaded to the correct subdirectory on Hugging Face.") raise RuntimeError(error_msg) from e # Load persona config try: from huggingface_hub import hf_hub_download config_path = hf_hub_download( repo_id=ADAPTER_REPO, filename=f"{adapter_subdir}/persona_config.json", repo_type="model" ) with open(config_path, 'r') as f: current_config = json.load(f) except Exception as e: print(f"⚠️ Could not load persona_config.json: {e}") print("Using default persona config.") current_config = {"persona_name": persona_key.title(), "persona_description": ""} current_persona = persona_key current_tokenizer = base_tokenizer print(f"✅ Loaded {persona_key} persona") return current_model, current_tokenizer, current_config def generate_response(persona_key: str, message: str, history: list, max_tokens: int = 80): """Generate a response from the selected persona.""" global current_model, current_tokenizer, current_config if not message or not message.strip(): return history, "" try: # Load adapter if needed model, tokenizer, config = load_persona_adapter(persona_key) # Build messages with conversation history system_prompt = "" if config: system_prompt = f"You are {config.get('persona_name', '')}. {config.get('persona_description', '')}" messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) # Add conversation history (last 5 exchanges to avoid too long context) # Handle both tuple format [user_msg, assistant_msg] and messages format for item in history[-5:]: if isinstance(item, dict) and "role" in item: # Messages format messages.append(item) else: # Tuple format [user_msg, assistant_msg] user_msg, assistant_msg = item messages.append({"role": "user", "content": user_msg}) messages.append({"role": "assistant", "content": assistant_msg}) # Add current message messages.append({"role": "user", "content": message}) # Apply chat template formatted = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) # Tokenize inputs = tokenizer(formatted, return_tensors="pt", truncation=True, max_length=512) # Move inputs to the same device as the model device = next(model.parameters()).device inputs = {k: v.to(device) for k, v in inputs.items()} # Generate with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_tokens, temperature=0.7, top_p=0.9, do_sample=True, pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.2, no_repeat_ngram_size=3, ) # Extract only the newly generated tokens input_length = inputs['input_ids'].shape[1] generated_tokens = outputs[0][input_length:] # Decode only the generated part response = tokenizer.decode(generated_tokens, skip_special_tokens=True) # Clean up response response = response.strip() if tokenizer.eos_token: response = response.replace(tokenizer.eos_token, "").strip() if tokenizer.pad_token: response = response.replace(tokenizer.pad_token, "").strip() # Remove chat template artifacts response = response.replace("<|system|>", "").replace("", "") response = response.replace("<|user|>", "").replace("", "") response = response.replace("<|assistant|>", "").replace("", "") response = response.replace("<|", "").replace("|>", "") # Clean up extra whitespace response = " ".join(response.split()) response = response.strip() # Update history with messages format (Gradio expects this format) history.append({"role": "user", "content": message}) history.append({"role": "assistant", "content": response}) return history, "" except Exception as e: error_msg = f"Error generating response: {str(e)}" print(error_msg) return history, error_msg def clear_chat(): """Clear the chat history.""" return [], "" # Empty list for messages format # Create Gradio interface with gr.Blocks(title="Persona Chat") as app: gr.Markdown( """ # 🐾 Persona Chat - Talk to Animals! Chat with three different animal personas, each with their own unique personality: - **🐕 Dog**: Friendly, playful, and enthusiastic - **🐱 Cat**: Independent, curious, and sometimes sassy - **🐦 Bird**: Energetic, talkative, and free-spirited **💻 Running on CPU** - Responses may be slower but will work perfectly! """ ) with gr.Row(): with gr.Column(scale=1): persona_dropdown = gr.Dropdown( choices=["dog", "cat", "bird"], value="dog", label="Select Persona" ) max_tokens_slider = gr.Slider( minimum=20, maximum=150, value=80, step=10, label="Max Response Length" ) clear_btn = gr.Button("Clear Chat") with gr.Column(scale=3): chatbot = gr.Chatbot( label="Chat", height=400 ) msg_input = gr.Textbox( label="Your Message", placeholder="Type your message here...", lines=2 ) send_btn = gr.Button("Send") # Event handlers def chat_fn(persona, message, history, max_tokens): return generate_response(persona, message, history, max_tokens) send_btn.click( fn=chat_fn, inputs=[persona_dropdown, msg_input, chatbot, max_tokens_slider], outputs=[chatbot, msg_input] ) msg_input.submit( fn=chat_fn, inputs=[persona_dropdown, msg_input, chatbot, max_tokens_slider], outputs=[chatbot, msg_input] ) clear_btn.click( fn=clear_chat, outputs=[chatbot, msg_input] ) if __name__ == "__main__": # Load base model first print("Initializing...") load_base_model() app.launch( server_name="0.0.0.0" if os.getenv("SPACE_ID") else "127.0.0.1", server_port=int(os.getenv("PORT", 7860)), share=False )