import gradio as gr
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# --- Configuration ---
MODEL_NAME = "NorwAI/NorwAI-Llama2-7B"  # "google/gemma-2-9b"

# --- Model Loading (Explicit) ---
# Use a try-except block to handle potential loading errors.
try:
    # Load the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Load the model with appropriate configurations.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",           # Let Transformers handle device placement.
        torch_dtype=torch.bfloat16,  # bfloat16 reduces memory usage (if supported by your hardware).
    )
except Exception as e:
    print(f"Error loading model: {e}")
    # Re-raise so the app fails fast instead of serving a broken interface.
    raise
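
# Note (assumption, not in the original app): a 7B model in bfloat16 needs
# roughly 14 GB of GPU memory, which exceeds the smaller Spaces hardware tiers
# and is a common cause of a "Runtime error" status like the one this Space
# showed. A minimal 4-bit loading sketch with bitsandbytes (requires the
# `bitsandbytes` package) would replace the from_pretrained call above:
#
#   from transformers import BitsAndBytesConfig
#
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_NAME,
#       device_map="auto",
#       quantization_config=BitsAndBytesConfig(
#           load_in_4bit=True,
#           bnb_4bit_compute_dtype=torch.bfloat16,
#       ),
#   )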

# --- Inference Function ---
def respond(message, history, system_message, max_tokens, temperature, top_p):
    try:
        # Build the conversation history. The <start_of_turn>/<end_of_turn>
        # tags and the "user"/"model" roles follow Gemma's chat format, kept
        # from the commented-out Gemma configuration (see note below).
        formatted_history = ""
        for user_msg, model_msg in history:
            formatted_history += f"<start_of_turn>user\n{user_msg}<end_of_turn>\n"
            if model_msg:  # Skip turns where the model has not answered yet.
                formatted_history += f"<start_of_turn>model\n{model_msg}<end_of_turn>\n"

        # Combine system message, history, and the current message.
        prompt = (
            f"<start_of_turn>system\n{system_message}<end_of_turn>\n"
            f"{formatted_history}"
            f"<start_of_turn>user\n{message}<end_of_turn>\n"
            f"<start_of_turn>model\n"
        )
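
        # Note (assumption, not in the original app): NorwAI-Llama2-7B is
        # Llama-2-based and may expect a different prompt format than the
        # Gemma tags above. If the tokenizer ships a chat template, a safer
        # sketch is:
        #
        #   messages = [{"role": "system", "content": system_message}]
        #   for user_msg, model_msg in history:
        #       messages.append({"role": "user", "content": user_msg})
        #       if model_msg:
        #           messages.append({"role": "assistant", "content": model_msg})
        #   messages.append({"role": "user", "content": message})
        #   prompt = tokenizer.apply_chat_template(
        #       messages, tokenize=False, add_generation_prompt=True
        #   )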

        # Tokenize the input and move it to the model's device.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate with streaming (important for a chatbot). `generate` blocks
        # until completion, so it runs in a background thread while a
        # TextIteratorStreamer yields decoded text chunks in the foreground.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,  # Enable sampling for more diverse responses.
            streamer=streamer,
            pad_token_id=tokenizer.eos_token_id,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Accumulate the response; the streamer yields already-decoded strings.
        response = ""
        for chunk in streamer:
            response += chunk
            yield response
    except Exception as e:
        print(f"Error during inference: {e}")
        yield "An error occurred during generation."
        return
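
# Optional local sanity check (assumption, not in the original app): stream a
# short reply once before wiring up the UI. The Norwegian prompt is just an
# example for this Norwegian-tuned model.
#
#   last = ""
#   for last in respond("Hei, hvem er du?", [], "You are a friendly Chatbot.", 64, 0.7, 0.95):
#       pass
#   print(last)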

# --- Gradio Interface ---
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
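
# Note (assumption about the Gradio version): `respond` is a generator, so
# streaming to the UI relies on Gradio's queue. Gradio 4+ enables the queue by
# default; on Gradio 3.x you would call `demo.queue().launch()` instead.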

if __name__ == "__main__":
    demo.launch()