|
|
import os |
|
|
import gdown |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer |
|
|
import gradio as gr |
|
|
|
|
|
# Ensure a local directory exists for the downloaded weights and the
# tokenizer/config files loaded from "model/" below.
os.makedirs("model", exist_ok=True)

# Google Drive link to the fine-tuned weights (safetensors format).
MODEL_URL = "https://drive.google.com/uc?id=1Kg8KSGIgjBopeOKSbYbFWEgUlYOcqyXX"
MODEL_PATH = "model/model.safetensors"

# Download the weights once; later runs reuse the cached file.
if not os.path.exists(MODEL_PATH):
    print("⬇ Downloading model weights...")
    gdown.download(MODEL_URL, MODEL_PATH, quiet=False)
else:
    print("✅ Model file already exists")

# Load tokenizer and model from the local "model" directory.
# NOTE(review): only the .safetensors file is downloaded above — the
# tokenizer/config JSON files must already exist in "model/"; confirm.
print("🔧 Loading model & tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("model")
model = AutoModelForCausalLM.from_pretrained("model", torch_dtype=torch.float16)

# Prefer GPU when available.
# NOTE(review): float16 weights on CPU can be slow or unsupported for some
# ops — confirm the CPU fallback path is actually usable.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Streams generated tokens to stdout as they are produced.
# NOTE(review): `streamer` is created here but never passed to
# model.generate() in respond() — confirm whether streaming was intended.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
|
|
|
|
def respond(message, history, max_tokens, temperature, top_p):
    """Generate a chat reply for *message* given the prior conversation.

    Args:
        message: Latest user message.
        history: Prior (user, assistant) turn pairs as supplied by
            gr.ChatInterface's tuple format.
            # NOTE(review): newer Gradio versions may pass messages-format
            # dicts instead of tuples — confirm against the installed version.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The assistant's reply text (new tokens only, special tokens removed).
    """
    # Rebuild the conversation using the model's chat markers.
    history_text = ""
    if history:
        for user, bot in history:
            history_text += f"<|user|>{user}<|assistant|>{bot}"

    full_input = history_text + f"<|user|>{message}<|assistant|>"

    inputs = tokenizer(full_input, return_tensors="pt").to(device)

    # No gradients are needed for generation; inference_mode saves memory.
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. This is more robust than
    # decoding everything and splitting on "<|assistant|>": with
    # skip_special_tokens=True the marker may be stripped from the decoded
    # text, in which case the split would return the whole conversation,
    # prompt included.
    prompt_len = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return answer.strip()
|
|
|
|
|
|
|
|
# Generation controls exposed in the UI; they map positionally onto the
# max_tokens / temperature / top_p parameters of respond().
_generation_controls = [
    gr.Slider(64, 1024, value=256, label="Max Tokens"),
    gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
]

# Chat UI wired to the respond() generation function.
chat = gr.ChatInterface(
    fn=respond,
    additional_inputs=_generation_controls,
    title="🦙 TinyLLaMA Chatbot",
    description="Fine-tuned TinyLLaMA using QLoRA.",
)
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server (blocks until the app is stopped).
    chat.launch()
|
|
|