import os

import gradio as gr
import plotly.express as px
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlenderbotForConditionalGeneration,
    pipeline,
)

# Check if CUDA is available and set the device accordingly.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Set environment variables for GPU usage and memory allocation if CUDA is available.
if device == "cuda":
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.8)  # Adjust the fraction as needed.
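
# Note: PYTORCH_CUDA_ALLOC_CONF is read when the CUDA allocator first runs, so it
# must be set before any tensor is placed on the GPU. expandable_segments reduces
# allocator fragmentation for variable-length generation workloads, and
# set_per_process_memory_fraction(0.8) caps this process at 80% of GPU memory.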

# System message (placeholder, adjust as needed).
system_message = ""

# Load the AWQ-quantized CapybaraHermes model and tokenizer (GPU-oriented).
def hermes_model():
    tokenizer = AutoTokenizer.from_pretrained("TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ")
    model = AutoModelForCausalLM.from_pretrained(
        "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ",
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    return model, tokenizer

# Load the lightweight Blenderbot model and tokenizer (CPU-friendly).
def blender_model():
    model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
    tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
    return model, tokenizer
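
# A minimal swap, assuming a GPU and the autoawq package are available for the
# AWQ checkpoint: replace the call below with
#   model, tokenizer = hermes_model()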
model, tokenizer = blender_model()

def chat_response(msg_prompt: str) -> str:
    """Generate a single Blenderbot reply for the given prompt."""
    try:
        inputs = tokenizer(msg_prompt, return_tensors="pt")
        reply_ids = model.generate(**inputs)
        return tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
    except Exception as e:
        return str(e)
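
# Illustrative usage (the exact wording of the reply depends on the model):
#   reply = chat_response("Hello, how are you?")
#   print(reply)  # e.g. " I'm doing well, thank you. How are you?"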

# Function to generate a response from the model via a text-generation pipeline.
# Note: this path assumes the causal CapybaraHermes model from hermes_model();
# Blenderbot is a seq2seq model and is not compatible with "text-generation".
def chat_responses(msg_prompt: str) -> str:
    """
    Generates a response from the model given a prompt.

    Args:
        msg_prompt (str): The user's message prompt.

    Returns:
        str: The model's response.
    """
    generation_params = {
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 40,
        "max_new_tokens": 512,
        "repetition_penalty": 1.1,
    }
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **generation_params)
    try:
        # ChatML prompt format used by CapybaraHermes-2.5-Mistral-7B.
        prompt_template = f'''<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{msg_prompt}<|im_end|>
<|im_start|>assistant
'''
        pipe_output = pipe(prompt_template)[0]["generated_text"]
        # Keep only the text after the final assistant marker.
        response_lines = pipe_output.split("<|im_start|>assistant")
        assistant_response = response_lines[-1].strip() if len(response_lines) > 1 else pipe_output.strip()
        return assistant_response
    except Exception as e:
        return str(e)
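
# Illustrative usage, assuming the Hermes model has been loaded instead:
#   model, tokenizer = hermes_model()
#   print(chat_responses("Write a haiku about autumn."))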

# Function to generate an example plot (a fixed iris scatter, despite the name).
def random_plot():
    df = px.data.iris()
    fig = px.scatter(
        df,
        x="sepal_width",
        y="sepal_length",
        color="species",
        size="petal_length",
        hover_data=["petal_width"],
    )
    return fig

# Function to handle likes/dislikes (for demonstration purposes).
def print_like_dislike(x: gr.LikeData):
    print(x.index, x.value, x.liked)

# Function to add the user's message and any uploaded files to the chat history.
def add_message(history, message, files):
    if files is not None:
        for file in files:
            # File entries are (filepath,) tuples so the Chatbot renders them as
            # attachments; outer lists (not tuples) so bot() can fill in the reply in place.
            history.append([(file,), None])
    if message:
        history.append([message, None])
    # Clear the textbox and disable it while the bot is responding.
    return history, gr.update(value=None, interactive=False)

# Function to generate the bot's reply for the latest user message.
def bot(history):
    if history:
        user_message = history[-1][0]
        history[-1][1] = chat_response(user_message)
    return history

fig = random_plot()
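# Note: `fig` is created but never attached to the layout below. An assumption on
# intent: it could be shown by adding `gr.Plot(value=fig)` inside the Blocks context.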

# Gradio interface setup
with gr.Blocks(fill_height=True) as demo:
    chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, scale=1)
    with gr.Row():
        chat_input = gr.Textbox(placeholder="Enter message...", show_label=False)
        file_input = gr.File(label="Upload file(s)", file_count="multiple")

    chat_msg = chat_input.submit(add_message, [chatbot, chat_input, file_input], [chatbot, chat_input])
    bot_msg = chat_msg.then(bot, chatbot, chatbot)
    # Re-enable the textbox once the bot has responded.
    bot_msg.then(lambda: gr.update(interactive=True), None, [chat_input])
    chatbot.like(print_like_dislike, None, None)

demo.queue()
demo.launch()
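
# When running outside a hosted Space, `demo.launch(share=True)` would expose a
# temporary public URL for the app.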