# Hugging Face Spaces page residue (Space status: "Sleeping") — kept as a comment.
| import os | |
| from gradio.components import clear_button | |
| import torch | |
| import gradio as gr | |
| import requests | |
| from v1.usta_model import UstaModel | |
| from v1.usta_tokenizer import UstaTokenizer | |
# Module-level state shared by all UI callbacks; replaced by the model loaders.
model, tokenizer, model_status = None, None, "Not Loaded"
def load_model(custom_model_path=None):
    """Build a UstaModel, load checkpoint weights, and return it with its tokenizer.

    Args:
        custom_model_path: Optional path to a ``.pth`` state-dict file. When it is
            falsy or does not exist, the bundled "v1/u1_model.pth" is used.

    Returns:
        (model, tokenizer, status_message) on success;
        (None, None, error_message) on any failure.
    """
    try:
        u_tokenizer = UstaTokenizer("v1/tokenize.json")
        print(f"Tokenizer loaded successfully, vocab size: {len(u_tokenizer.vocab)}")

        # Architecture hyperparameters — must match the checkpoint being loaded.
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8

        model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers,
        )

        # Fall back to the bundled checkpoint when no (valid) custom path is given.
        if custom_model_path and os.path.exists(custom_model_path):
            weights_path = custom_model_path
        else:
            weights_path = "v1/u1_model.pth"
        # map_location="cpu" lets GPU-trained checkpoints load on CPU-only hosts.
        model.load_state_dict(torch.load(weights_path, map_location="cpu"))
        model.eval()

        # BUG FIX: the old message printed the vocab size as the parameter count.
        num_params = sum(p.numel() for p in model.parameters())
        print(f"Model loaded successfully, model parameters: {num_params}")
        return model, u_tokenizer, "Model Loaded Successfully"
    except Exception as e:
        return None, None, f"Error Loading Model: {e}"
# Eagerly load the default model at import time so the UI starts ready.
# (load_model traps its own errors; the try is a belt-and-braces guard.)
try:
    model, tokenizer, model_status = load_model()
except Exception as err:
    print(f"Error loading model: {err}")
    model, tokenizer, model_status = None, None, "Error Loading Model"

print(f"Model status: {model_status}")
if model is not None:
    print("Model loaded successfully")
def chat_with_model(message, chat_history, max_new_tokens=20):
    """Generate a model reply for `message` and append the exchange to history.

    Args:
        message: User prompt text.
        chat_history: List of (user, bot) tuples backing the Chatbot widget;
            mutated in place.
        max_new_tokens: Upper bound on generated tokens (clamped so prompt +
            generation fit the 32-token context window).

    Returns:
        (updated_chat_history, textbox_value) — textbox is cleared on success.
    """
    try:
        # BUG FIX: previously a missing model crashed into the generic error
        # path; report the real cause instead.
        if model is None or tokenizer is None:
            chat_history.append((message, "Model is not loaded yet — load a model first."))
            return chat_history, ""

        tokens = tokenizer.encode(message)
        # Keep only the tail of long prompts so generation has room within
        # the model's 32-token context.
        if len(tokens) > 25:
            tokens = tokens[-25:]

        with torch.no_grad():
            actual_max_tokens = min(max_new_tokens, 32 - len(tokens))
            generated_tokens = model.generate(tokens, max_new_tokens=actual_max_tokens)

        response = tokenizer.decode(generated_tokens)
        # Strip the echoed prompt so only the continuation is shown.
        original_message = tokenizer.decode(tokens.tolist())
        if response.startswith(original_message):
            response = response[len(original_message):]
        response = response.replace("<pad>", "").replace("<unk>", "").strip()

        print(f"uzunluk {len(response)}")  # "uzunluk" = Turkish for "length" (debug output)
        if not response:
            response = "I am sorry i dont know the answer to that question"
        chat_history.append((message, response))
        return chat_history, ""
    except Exception as e:
        print(f"Error generating response {e}")
        return chat_history, "Error generating response"
def load_model_from_url(custom_model_url):
    """Download a ``.pth`` checkpoint from `custom_model_url` and swap it in.

    Updates the module-level `model`, `tokenizer`, and `model_status` globals.

    Returns:
        A human-readable status string for the UI.
    """
    global model, tokenizer, model_status
    temp_file = "temp_model_pth"
    try:
        headers = {
            "Accept": "application/octet-stream",
            "User-Agent": "Mozilla5.0 (Windows NT 10.0; Win64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
        }
        # Timeout prevents the UI callback from hanging forever on a dead host.
        response = requests.get(custom_model_url, headers=headers, timeout=60)
        response.raise_for_status()

        with open(temp_file, "wb") as f:
            f.write(response.content)

        model, tokenizer, model_status = load_model(temp_file)
        # BUG FIX: surface load failures instead of always claiming success.
        if model is None:
            return model_status
        return "Model loaded successfully on url"
    except Exception as e:
        print(f"Error loading model from url {e}")
        return "Error loading model from url"
    finally:
        # BUG FIX: the temp file used to leak when the download or load raised.
        if os.path.exists(temp_file):
            os.remove(temp_file)
def load_model_from_file(model_file):
    """Load a checkpoint from an uploaded file object (a ``gr.File`` value).

    Updates the module-level `model`, `tokenizer`, and `model_status` globals.

    Returns:
        A human-readable status string for the UI.
    """
    global model, tokenizer, model_status
    # BUG FIX: clicking "Load Model" with no file selected used to raise
    # AttributeError on model_file.name and report a generic error.
    if model_file is None:
        return "No file selected"
    try:
        model, tokenizer, model_status = load_model(model_file.name)
        # Surface load failures instead of always claiming success.
        if model is None:
            return model_status
        return " Model loaded on file"
    except Exception as e:
        print(f"error loading model on file {e}")
        return "Error loading model on file"
# Gradio UI: chat panel plus controls for hot-swapping the model checkpoint.
with gr.Blocks(title="Usta Model") as demo:
    gr.Markdown("# Usta Model")
    gr.Markdown(" Chat with the model")

    chatbot = gr.Chatbot(height=300)
    msg = gr.Textbox(placeholder="Enter your text here...", label="Message")
    with gr.Row():
        send_button = gr.Button("Send", variant="primary")
        clear_button = gr.Button("Clear")
    max_new_tokens = gr.Slider(
        minimum=1,
        maximum=30,
        value=20,
        step=1,
        label="Max New Tokens",
        info="The maximum number of new tokens to generate",
    )

    gr.Markdown("## LOAD CUSTOM MODEL")
    with gr.Row():
        custom_model_url = gr.Textbox(
            placeholder="https://github.com/malibayram/llm-from-scratch/raw/refs/heads/main/u_model_4000.pth",
            label="Custom Model url",
            scale=4,
        )
        load_url_button = gr.Button("Load Model", variant="primary", scale=1)
    with gr.Row():
        model_file = gr.File(
            label="Custom Model File",
            file_types=[".pth", ".pt", ".bin"],
        )
        load_file_button = gr.Button("Load Model", variant="primary")
    status = gr.Textbox(
        label="Model Status",
        value=model_status,
        interactive=False,
    )

    def send_message(message, chat_history, max_new_tokens):
        """Ignore empty/whitespace messages; otherwise delegate to chat_with_model."""
        if not message.strip():
            return chat_history, ""
        return chat_with_model(message, chat_history, max_new_tokens)

    send_button.click(
        send_message,
        inputs=[msg, chatbot, max_new_tokens],
        outputs=[chatbot, msg],
    )
    msg.submit(
        send_message,
        inputs=[msg, chatbot, max_new_tokens],
        outputs=[chatbot, msg],
    )
    # BUG FIX: `status` was passed as a stray 4th positional argument to
    # click() (the api_name slot), which is invalid; clearing resets the
    # chatbot only.
    clear_button.click(lambda: None, None, chatbot)
    load_url_button.click(
        load_model_from_url,
        inputs=[custom_model_url],
        outputs=[status],
    )
    load_file_button.click(
        load_model_from_file,
        inputs=[model_file],
        outputs=[status],
    )
# Entry point: launch the Gradio app (share=True requests a public tunnel link;
# it is a no-op warning on Hugging Face Spaces).
if __name__ == "__main__":
    demo.launch(share=True)