from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import torch

# Load the pre-trained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("ahmed792002/alzheimers_memory_support_ai")
model = AutoModelForCausalLM.from_pretrained("ahmed792002/alzheimers_memory_support_ai")
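
# Optional (not part of the original Space): run generation on a GPU when one
# is available. A minimal sketch, left commented out because chatbot() below
# would also need to move input_ids to the same device before generating.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)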

# Chatbot function
def chatbot(query, history, system_message, max_length, temperature, top_k, top_p):
    """
    Processes a user query through the fine-tuned model and returns a response.

    Note: `history` and `system_message` are accepted for compatibility with
    gr.ChatInterface but are not currently used during generation.
    """
    # Tokenize the input query
    input_ids = tokenizer.encode(query, return_tensors="pt")

    # Retry until the model produces a non-trivial reply
    response = "."
    while response == ".":
        # Generate text using sampling
        final_outputs = model.generate(
            input_ids,
            do_sample=True,
            max_length=int(max_length),        # Convert max_length to integer
            temperature=float(temperature),    # Convert temperature to float
            top_k=int(top_k),                  # Convert top_k to integer
            top_p=float(top_p),                # Convert top_p to float
            pad_token_id=tokenizer.pad_token_id,
        )
        # Decode the generated tokens
        response = tokenizer.decode(final_outputs[0], skip_special_tokens=True)
        # The model wraps its reply in double quotes; extract the quoted part,
        # falling back to the full text if no quotes are present (the original
        # unguarded split('"')[1] raised IndexError in that case)
        parts = response.split('"')
        response = parts[1] if len(parts) > 1 else response
    return response

# Gradio ChatInterface
demo = gr.ChatInterface(
    chatbot,
    additional_inputs=[
        gr.Textbox(value="You are a friendly chatbot.", label="System message"),
        gr.Slider(128, 1024, value=256, step=64, label="Max Length"),      # Slider for max_length
        gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),     # Slider for temperature
        gr.Slider(1, 100, value=50, step=1, label="Top-K"),                # Slider for top_k
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-P"),         # Slider for top_p
    ],
)
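
# A quick smoke test of the generation pipeline without launching the UI.
# This is a minimal sketch, not part of the original Space; the query is
# hypothetical and the argument values mirror the slider defaults above,
# which are assumptions rather than tuned settings.
def smoke_test():
    reply = chatbot(
        "What day is it today?",                      # sample query (hypothetical)
        history=[],                                   # unused by chatbot()
        system_message="You are a friendly chatbot.",  # unused by chatbot()
        max_length=256,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
    )
    print(reply)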

if __name__ == "__main__":
    demo.launch(share=True)  # Set `share=True` to create a public link