import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
import re
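# NOTE: the `spaces` import provides the @spaces.GPU decorator used below to
# request ZeroGPU hardware for the generation call.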
# Model configurations
MODELS = {
    "Athena-R3X 8B": "Spestly/Athena-R3X-8B",
    "Athena-R3X 4B": "Spestly/Athena-R3X-4B",
    "Athena-R3 7B": "Spestly/Athena-R3-7B",
    "Athena-3 3B": "Spestly/Athena-3-3B",
    "Athena-3 7B": "Spestly/Athena-3-7B",
    "Athena-3 14B": "Spestly/Athena-3-14B",
    "Athena-2 1.5B": "Spestly/Athena-2-1.5B",
    "Athena-1 3B": "Spestly/Athena-1-3B",
    "Athena-1 7B": "Spestly/Athena-1-7B",
}
@spaces.GPU  # request a ZeroGPU device for the duration of this call
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    """Generate response using ZeroGPU - all CUDA operations happen here"""
    print(f"🚀 Loading {model_id}...")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    load_time = time.time() - start_time
    print(f"✅ Model loaded in {load_time:.2f}s")
    # Build messages in proper chat format (OpenAI-style messages)
    messages = []
    system_prompt = (
        "You are Athena, a helpful, harmless, and honest AI assistant. "
        "You provide clear, accurate, and concise responses to user questions. "
        "You are knowledgeable across many domains and always aim to be respectful and helpful. "
        "You are finetuned by Aayan Mishra"
    )
    messages.append({"role": "system", "content": system_prompt})
    # Add conversation history
    for msg in conversation:
        messages.append(msg)
    # Add current user message
    messages.append({"role": "user", "content": user_message})
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    generation_start = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    generation_time = time.time() - generation_start
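    # Decode only the newly generated tokens by slicing off the prompt
    # portion of the output sequence.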
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[-1]:],
        skip_special_tokens=True
    ).strip()
    print(f"Generation time: {generation_time:.2f}s")
    return response, load_time, generation_time
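# Example usage of generate_response (hypothetical call, outside Gradio):
#   reply, load_s, gen_s = generate_response("Spestly/Athena-R3X-4B", [], "Hi!")
# Note the model is re-loaded from the Hub cache on every call, the usual
# pattern for stateless ZeroGPU workers.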
def format_response_with_thinking(response):
    """Format response to handle <think></think> tags"""
    # Check if response contains thinking tags
    if '<think>' in response and '</think>' in response:
        # Split the response into parts
        pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
        match = re.search(pattern, response, re.DOTALL)
        if match:
            before_thinking = match.group(1).strip()
            thinking_content = match.group(3).strip()
            after_thinking = match.group(4).strip()
            # Create HTML with collapsible thinking section
            html = f"{before_thinking}\n"
            html += '<div class="thinking-container">'
            html += '<button class="thinking-toggle" onclick="this.nextElementSibling.classList.toggle(\'hidden\'); this.textContent = this.textContent === \'Show reasoning\' ? \'Hide reasoning\' : \'Show reasoning\'">Show reasoning</button>'
            html += f'<div class="thinking-content hidden">{thinking_content}</div>'
            html += '</div>\n'
            html += after_thinking
            return html
    # If no thinking tags, return the original response
    return response
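# Example: format_response_with_thinking(
#     "<think>User wants a definition.</think>AI is the simulation of ...")
# returns HTML with a "Show reasoning" toggle button wrapping the <think>
# content; responses without the tags pass through unchanged.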
def chat_submit(message, history, conversation_state, model_name, max_length, temperature):
    """Process a new message and update the chat history"""
    if not message.strip():
        return "", history, conversation_state
    model_id = MODELS.get(model_name, MODELS["Athena-R3X 4B"])
    try:
        # Print debug info to help diagnose issues
        print(f"Processing message: {message}")
        print(f"Selected model: {model_name} ({model_id})")
        response, load_time, generation_time = generate_response(
            model_id, conversation_state, message, max_length, temperature
        )
        # Update the conversation state with the raw response
        conversation_state.append({"role": "user", "content": message})
        conversation_state.append({"role": "assistant", "content": response})
        # Format the response for display
        formatted_response = format_response_with_thinking(response)
        # Update the visible chat history
        history.append((message, formatted_response))
        print(f"Response added to history. Current length: {len(history)}")
        return "", history, conversation_state
    except Exception as e:
        import traceback
        print(f"Error in chat_submit: {str(e)}")
        print(traceback.format_exc())
        error_message = f"Error: {str(e)}"
        history.append((message, error_message))
        return "", history, conversation_state
| css = """ | |
| .message { | |
| padding: 10px; | |
| margin: 5px; | |
| border-radius: 10px; | |
| } | |
| .thinking-container { | |
| margin: 10px 0; | |
| } | |
| .thinking-toggle { | |
| background-color: #f1f1f1; | |
| border: 1px solid #ddd; | |
| border-radius: 4px; | |
| padding: 5px 10px; | |
| cursor: pointer; | |
| font-size: 0.9em; | |
| margin-bottom: 5px; | |
| color: #555; | |
| } | |
| .thinking-content { | |
| background-color: #f9f9f9; | |
| border-left: 3px solid #ccc; | |
| padding: 10px; | |
| margin-top: 5px; | |
| font-size: 0.95em; | |
| color: #555; | |
| font-family: monospace; | |
| white-space: pre-wrap; | |
| overflow-x: auto; | |
| } | |
| .hidden { | |
| display: none; | |
| } | |
| """ | |
with gr.Blocks(title="Athena Playground Chat", css=css, theme='NoCrypt/miku') as demo:
    gr.Markdown("# 🚀 Athena Playground Chat")
    gr.Markdown("*Powered by HuggingFace ZeroGPU*")
    # State to keep track of the conversation for the model
    conversation_state = gr.State([])
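    # conversation_state stores the raw OpenAI-style messages (with any
    # <think> tags intact) for the model, while the chatbot below shows
    # the HTML-formatted version.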
    chatbot = gr.Chatbot(height=500, label="Athena", render_markdown=True)
    with gr.Row():
        user_input = gr.Textbox(label="Your message", scale=8, autofocus=True, placeholder="Type your message here...")
        send_btn = gr.Button(value="Send", scale=1, variant="primary")
    # Clear button for resetting the conversation
    clear_btn = gr.Button("Clear Conversation")
    # Configuration controls
    gr.Markdown("### ⚙️ Model & Generation Settings")
    with gr.Row():
        model_choice = gr.Dropdown(
            label="📱 Model",
            choices=list(MODELS.keys()),
            value="Athena-R3X 4B",
            info="Select which Athena model to use"
        )
        max_length = gr.Slider(
            32, 8192, value=512,
            label="📏 Max Tokens",
            info="Maximum number of tokens to generate"
        )
        temperature = gr.Slider(
            0.1, 2.0, value=0.7,
            label="🎨 Creativity",
            info="Higher values = more creative responses"
        )
    # Function to clear the conversation
    def clear_conversation():
        return [], []

    # Connect the interface components - note the specific ordering
    user_input.submit(
        chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state]
    )
    # Make sure the send button uses the exact same function with the same parameter ordering
    send_btn.click(
        chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state]
    )
    # Connect the clear button
    clear_btn.click(clear_conversation, outputs=[chatbot, conversation_state])
    # Add examples if desired
    gr.Examples(
        examples=[
            "What is artificial intelligence?",
            "Can you explain quantum computing?",
            "Write a short poem about technology",
            "What are some ethical concerns about AI?"
        ],
        inputs=[user_input]
    )
| gr.Markdown(""" | |
| ### About the Thinking Tags | |
| Some Athena models (particularly R3X series) include reasoning in `<think></think>` tags. | |
| Click "Show reasoning" to see the model's thought process behind its answers. | |
| """) | |
if __name__ == "__main__":
    demo.launch(debug=True)  # Enable debug mode for better error reporting