"""Gradio chat demo for UstaModel: loads the bundled checkpoint at startup and
lets the user swap in a custom checkpoint from a URL or an uploaded file."""

import os

import gradio as gr
import requests
import torch

from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer

model, tokenizer, model_status = None, None, "Not Loaded"


def load_model(custom_model_path=None):
    try:
        u_tokenizer = UstaTokenizer("v1/tokenize.json")
        print(f"Tokenizer loaded successfully, vocab size: {len(u_tokenizer.vocab)}")

        # Model hyperparameters; these must match the checkpoint being loaded.
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8

        model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers,
        )

        # Load the custom checkpoint if one was provided, otherwise fall back
        # to the bundled default weights.
        if custom_model_path and os.path.exists(custom_model_path):
            model.load_state_dict(torch.load(custom_model_path))
        else:
            model.load_state_dict(torch.load("v1/u1_model.pth"))
        model.eval()

        num_params = sum(p.numel() for p in model.parameters())
        print(f"Model loaded successfully, model parameters: {num_params}")
        return model, u_tokenizer, "Model Loaded Successfully"
    except Exception as e:
        return None, None, f"Error Loading Model: {e}"


try:
    model, tokenizer, model_status = load_model()
except Exception as e:
    print(f"Error loading model: {e}")
    model, tokenizer, model_status = None, None, "Error Loading Model"

print(f"Model status: {model_status}")
if model is not None:
    print("Model loaded successfully")


def chat_with_model(message, chat_history, max_new_tokens=20):
    try:
        if model is None or tokenizer is None:
            chat_history.append((message, "Model is not loaded."))
            return chat_history, ""

        tokens = tokenizer.encode(message)
        # Keep the prompt inside the 32-token context window, leaving room
        # for generated tokens.
        if len(tokens) > 25:
            tokens = tokens[-25:]

        with torch.no_grad():
            actual_max_tokens = min(max_new_tokens, 32 - len(tokens))
            generated_tokens = model.generate(tokens, max_new_tokens=actual_max_tokens)

        response = tokenizer.decode(generated_tokens)

        # The generated sequence echoes the prompt; strip it from the response.
        original_message = tokenizer.decode(tokens.tolist())
        if response.startswith(original_message):
            response = response[len(original_message):]
        response = response.strip()

        print(f"Response length: {len(response)}")
        if len(response) <= 0:
            response = "I'm sorry, I don't know the answer to that question."

        chat_history.append((message, response))
        return chat_history, ""
    except Exception as e:
        print(f"Error generating response: {e}")
        return chat_history, "Error generating response"


def load_model_from_url(custom_model_url):
    global model, tokenizer, model_status
    try:
        # Some hosts (e.g. GitHub) reject requests without browser-like headers.
        headers = {
            "Accept": "application/octet-stream",
            "User-Agent": (
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
                "(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
            ),
        }
        response = requests.get(custom_model_url, headers=headers)
        response.raise_for_status()

        # Write the downloaded weights to a temporary file, load them,
        # then clean up.
        temp_file = "temp_model.pth"
        with open(temp_file, "wb") as f:
            f.write(response.content)

        model, tokenizer, model_status = load_model(temp_file)
        os.remove(temp_file)
        return "Model loaded successfully from URL"
    except Exception as e:
        print(f"Error loading model from URL: {e}")
        return "Error loading model from URL"


def load_model_from_file(model_file):
    global model, tokenizer, model_status
    try:
        model, tokenizer, model_status = load_model(model_file.name)
        return "Model loaded successfully from file"
    except Exception as e:
        print(f"Error loading model from file: {e}")
        return "Error loading model from file"


with gr.Blocks(title="Usta Model") as demo:
    gr.Markdown("# Usta Model")
    gr.Markdown("Chat with the model")

    chatbot = gr.Chatbot(height=300)
    msg = gr.Textbox(placeholder="Enter your text here...", label="Message")

    with gr.Row():
        send_button = gr.Button("Send", variant="primary")
        clear_button = gr.Button("Clear")

    max_new_tokens = gr.Slider(
        minimum=1,
        maximum=30,
        value=20,
        step=1,
        label="Max New Tokens",
        info="The maximum number of new tokens to generate",
    )

    gr.Markdown("## Load Custom Model")
    with gr.Row():
        custom_model_url = gr.Textbox(
            placeholder="https://github.com/malibayram/llm-from-scratch/raw/refs/heads/main/u_model_4000.pth",
            label="Custom Model URL",
            scale=4,
        )
        load_url_button = gr.Button("Load Model", variant="primary", scale=1)

    with gr.Row():
        model_file = gr.File(
            label="Custom Model File",
            file_types=[".pth", ".pt", ".bin"],
        )
        load_file_button = gr.Button("Load Model", variant="primary")

    status = gr.Textbox(
        label="Model Status",
        value=model_status,
        interactive=False,
    )

    def send_message(message, chat_history, max_new_tokens):
        # Ignore empty messages.
        if not message.strip():
            return chat_history, ""
        return chat_with_model(message, chat_history, max_new_tokens)

    send_button.click(
        send_message,
        inputs=[msg, chatbot, max_new_tokens],
        outputs=[chatbot, msg],
    )
    msg.submit(
        send_message,
        inputs=[msg, chatbot, max_new_tokens],
        outputs=[chatbot, msg],
    )
    clear_button.click(lambda: None, None, chatbot)
    load_url_button.click(
        load_model_from_url,
        inputs=[custom_model_url],
        outputs=[status],
    )
    load_file_button.click(
        load_model_from_file,
        inputs=[model_file],
        outputs=[status],
    )

if __name__ == "__main__":
    demo.launch(share=True)