import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import login

# Log in to Hugging Face Hub (set HF_AUTH_TOKEN in the Space secrets;
# the literal fallback string is only a placeholder and will not authenticate)
login(token=os.environ.get("HF_AUTH_TOKEN", "YOUR_HUGGING_FACE_TOKEN"))


# Load the model and tokenizer
def load_model():
    print("Loading model...")
    base_model_name = "mistralai/Mistral-7B-v0.1"
    adapter_model_name = "Psalms23Wave/Alkebulan-AI"

    # Use CUDA if available, otherwise CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Free up memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Load the base model
    print(f"Loading base model: {base_model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        token=os.environ.get("HF_AUTH_TOKEN", "YOUR_HUGGING_FACE_TOKEN"),
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto" if device == "cuda" else None
    )

    # Load adapter weights
    try:
        print(f"Loading adapter: {adapter_model_name}")
        model = PeftModel.from_pretrained(
            model,
            adapter_model_name,
            token=os.environ.get("HF_AUTH_TOKEN", "YOUR_HUGGING_FACE_TOKEN")
        )
        print("Adapter loaded successfully")
    except Exception as e:
        print(f"Error loading adapter: {str(e)}")
        print("Continuing with base model only")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        adapter_model_name,
        token=os.environ.get("HF_AUTH_TOKEN", "YOUR_HUGGING_FACE_TOKEN")
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Model and tokenizer loaded successfully!")
    return model, tokenizer
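
# A hypothetical smoke test (not part of the app flow): load_model() can be
# exercised directly, e.g. in a notebook, assuming enough RAM/VRAM for the
# 7B base model:
#
#   model, tokenizer = load_model()
#   ids = tokenizer("Oli otya?", return_tensors="pt").to(model.device)
#   print(tokenizer.decode(model.generate(**ids, max_new_tokens=20)[0], skip_special_tokens=True))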

# Define the chatbot function
def chat(message, history, language, max_tokens=100, temperature=0.7, model=None, tokenizer=None):
    """
    Generates a response from the chatbot.
    """
    try:
        # Build prompt
        prompt = f"Language: {language}\nUser: {message}\nBot:"

        # Tokenize with length limitation
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id
            )

        # Process response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.split("Bot:")[1].strip() if "Bot:" in response else response[len(prompt):].strip()

        # Clean up memory
        del inputs, outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return response
    except Exception as e:
        return f"Sorry, I encountered an error: {str(e)}. Please try again."

# Define the translation function
def translate(text, target_language, model, tokenizer):
    """
    Translates text to the target language.
    """
    try:
        prompt = f"Translate this to {target_language}: \"{text}\"\nTranslation:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=True,
                temperature=0.3,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id
            )

        translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
        translation = translation.split("Translation:")[1].strip() if "Translation:" in translation else translation.strip()

        del inputs, outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return translation
    except Exception as e:
        return f"Sorry, I encountered an error: {str(e)}. Please try again."

# Gradio interface with lazy loading
def create_interface():
    # Load model only when needed
    model_loaded = False
    model, tokenizer = None, None

    def load_model_if_needed():
        nonlocal model, tokenizer, model_loaded
        if not model_loaded:
            try:
                model, tokenizer = load_model()
                model_loaded = True
            except Exception as e:
                print(f"Error loading model: {str(e)}")
                raise e
        return model, tokenizer

    # Main interface function
    def chatbot_interface(message, history, language, max_tokens, temperature):
        try:
            # Load model on first use
            model, tokenizer = load_model_if_needed()

            # Process the message
| if "translate" in message.lower() and "to" in message.lower(): | |
| # Handle translation | |
| parts = message.lower().split("translate")[1].split("to") | |
| if len(parts) >= 2: | |
| text_to_translate = parts[0].strip() | |
| target_lang = parts[1].strip() | |
| translation = translate(text_to_translate, target_lang, model, tokenizer) | |
| response = f"Translation to {target_lang}: {translation}" | |
| else: | |
| response = "Please specify what to translate and target language, e.g., 'translate Hello to Luganda'" | |
            else:
                # Regular chat
                response = chat(message, history, language, max_tokens, temperature, model, tokenizer)

            # Add to history
            updated_history = history.copy()
            updated_history.append((message, response))
            return updated_history, updated_history
        except Exception as e:
            error_message = f"Sorry, I encountered an error: {str(e)}. Please try again."
            updated_history = history.copy()
            updated_history.append((message, error_message))
            return updated_history, updated_history

    # Create Gradio interface
    with gr.Blocks(title="Alkebulan AI Chatbot") as demo:
        gr.Markdown("# Alkebulan AI Chatbot")
        gr.Markdown("Chat in Luganda, Iteso, Runyankore, Acholi, or Ateso!")

        chat_history = gr.State([])

        with gr.Row():
            with gr.Column(scale=1):
                language = gr.Dropdown(
                    label="Select Language",
                    choices=["Luganda", "Iteso", "Runyankore", "Acholi", "Ateso", "English"],
                    value="Luganda"
                )
                with gr.Accordion("Advanced Settings", open=False):
                    max_tokens_slider = gr.Slider(minimum=50, maximum=200, value=100, step=10, label="Max Tokens")
                    temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
                gr.Markdown("""
                ## Examples:
                - Basic chat: Just type a message in the selected language
                - Translate: "translate How are you to Luganda"
                """)
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Chat History", height=400)
                with gr.Row():
                    message = gr.Textbox(
                        label="Your Message",
                        placeholder="Type your message here...",
                        scale=8
                    )
                    submit = gr.Button("Send", scale=1)
                    clear = gr.Button("Clear", scale=1)

        # Set up interactions
        submit.click(
            chatbot_interface,
            inputs=[message, chat_history, language, max_tokens_slider, temperature_slider],
            outputs=[chatbot, chat_history]
        )
        message.submit(
            chatbot_interface,
            inputs=[message, chat_history, language, max_tokens_slider, temperature_slider],
            outputs=[chatbot, chat_history]
        )
        clear.click(lambda: [], None, chatbot, queue=False)
        clear.click(lambda: [], None, chat_history, queue=False)

    return demo

# Launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)  # share=True creates a public URL
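
# Rough sketch for running this Space locally (package list is an assumption;
# pin versions as needed; accelerate is required for device_map="auto"):
#
#   pip install torch transformers peft accelerate gradio huggingface_hub
#   export HF_AUTH_TOKEN=hf_...   # needed if the model/adapter repos are gated or private
#   python app.py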