LLM_from / app.py
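"""Gradio chat demo for UstaModel, a small transformer language model.

The app loads the bundled v1 checkpoint at startup and also lets users swap
in custom weights, either from a URL or from an uploaded file.
"""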
import os

import gradio as gr
import requests
import torch

from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer
model, tokenizer, model_status = None, None, "Not Loaded"
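
# Build the tokenizer and model, then load weights from disk.
# Returns (model, tokenizer, status_message), or (None, None, error) on failure.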
def load_model(custom_model_path=None):
    try:
        u_tokenizer = UstaTokenizer("v1/tokenize.json")
        print(f"Tokenizer loaded successfully, vocab size: {len(u_tokenizer.vocab)}")

        # Fixed hyperparameters; context_length caps the total sequence length.
        context_length = 32
        vocab_size = len(u_tokenizer.vocab)
        embedding_dim = 12
        num_heads = 4
        num_layers = 8

        model = UstaModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_length=context_length,
            num_layers=num_layers,
        )

        # Prefer custom weights when a valid path is given; otherwise fall
        # back to the bundled checkpoint.
        if custom_model_path and os.path.exists(custom_model_path):
            model.load_state_dict(torch.load(custom_model_path))
        else:
            model.load_state_dict(torch.load("v1/u1_model.pth"))
        model.eval()

        num_params = sum(p.numel() for p in model.parameters())
        print(f"Model loaded successfully, model parameters: {num_params}")
        return model, u_tokenizer, "Model Loaded Successfully"
    except Exception as e:
        return None, None, f"Error Loading Model: {e}"
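
# Load the default model once at startup so the UI opens ready to chat.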
try:
    model, tokenizer, model_status = load_model()
except Exception as e:
    print(f"Error loading model: {e}")
    model, tokenizer, model_status = None, None, "Error Loading Model"

print(f"Model status: {model_status}")
if model is not None:
    print("Model loaded successfully")
def chat_with_model(message, chat_history, max_new_tokens=20):
    if model is None or tokenizer is None:
        return chat_history, "Model is not loaded"
    try:
        tokens = tokenizer.encode(message)
        # Keep only the last 25 tokens so prompt + completion fit the
        # 32-token context window.
        if len(tokens) > 25:
            tokens = tokens[-25:]

        with torch.no_grad():
            actual_max_tokens = min(int(max_new_tokens), 32 - len(tokens))
            generated_tokens = model.generate(tokens, max_new_tokens=actual_max_tokens)
        response = tokenizer.decode(generated_tokens)

        # Remove the echoed prompt and special tokens from the output.
        original_message = tokenizer.decode(tokens.tolist())
        if response.startswith(original_message):
            response = response[len(original_message):]
        response = response.replace("<pad>", "").replace("<unk>", "").strip()

        print(f"response length: {len(response)}")
        if len(response) == 0:
            response = "I'm sorry, I don't know the answer to that question."

        chat_history.append((message, response))
        return chat_history, ""
    except Exception as e:
        print(f"Error generating response: {e}")
        return chat_history, "Error generating response"
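
# Download weights over HTTP into a temporary file and hot-swap the globals.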
def load_model_from_url(custom_model_url):
    global model, tokenizer, model_status
    try:
        headers = {
            "Accept": "application/octet-stream",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
        }
        response = requests.get(custom_model_url, headers=headers)
        response.raise_for_status()

        temp_file = "temp_model.pth"
        with open(temp_file, "wb") as f:
            f.write(response.content)
        model, tokenizer, model_status = load_model(temp_file)
        os.remove(temp_file)
        return "Model loaded successfully from URL"
    except Exception as e:
        print(f"Error loading model from URL: {e}")
        return "Error loading model from URL"
def load_model_from_file(model_file):
    global model, tokenizer, model_status
    try:
        model, tokenizer, model_status = load_model(model_file.name)
        return "Model loaded successfully from file"
    except Exception as e:
        print(f"Error loading model from file: {e}")
        return "Error loading model from file"
with gr.Blocks(title="Usta Model") as demo:
    gr.Markdown("# Usta Model")
    gr.Markdown("Chat with the model")

    chatbot = gr.Chatbot(height=300)
    msg = gr.Textbox(placeholder="Enter your text here...", label="Message")

    with gr.Row():
        send_button = gr.Button("Send", variant="primary")
        clear_button = gr.Button("Clear")

    max_new_tokens = gr.Slider(
        minimum=1,
        maximum=30,
        value=20,
        step=1,
        label="Max New Tokens",
        info="The maximum number of new tokens to generate",
    )

    gr.Markdown("## Load Custom Model")
    with gr.Row():
        custom_model_url = gr.Textbox(
            placeholder="https://github.com/malibayram/llm-from-scratch/raw/refs/heads/main/u_model_4000.pth",
            label="Custom Model URL",
            scale=4,
        )
        load_url_button = gr.Button("Load Model", variant="primary", scale=1)

    with gr.Row():
        model_file = gr.File(
            label="Custom Model File",
            file_types=[".pth", ".pt", ".bin"],
        )
        load_file_button = gr.Button("Load Model", variant="primary")

    status = gr.Textbox(
        label="Model Status",
        value=model_status,
        interactive=False,
    )
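
    # Wire UI events to their handlers.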
    def send_message(message, chat_history, max_new_tokens):
        # Ignore empty or whitespace-only messages.
        if not message.strip():
            return chat_history, ""
        return chat_with_model(message, chat_history, max_new_tokens)

    send_button.click(
        send_message,
        inputs=[msg, chatbot, max_new_tokens],
        outputs=[chatbot, msg],
    )
    msg.submit(
        send_message,
        inputs=[msg, chatbot, max_new_tokens],
        outputs=[chatbot, msg],
    )
    clear_button.click(lambda: None, None, chatbot)
    load_url_button.click(
        load_model_from_url,
        inputs=[custom_model_url],
        outputs=[status],
    )
    load_file_button.click(
        load_model_from_file,
        inputs=[model_file],
        outputs=[status],
    )
if __name__ == "__main__":
    demo.launch(share=True)
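
# Usage (assumes the v1/ package and checkpoint ship alongside this file):
#   python app.py
# share=True additionally creates a temporary public Gradio link next to the
# local server URL.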