File size: 2,069 Bytes
aa5931c d44c8ed aa5931c 2431f8a aa5931c 2431f8a aa5931c 2431f8a aa5931c 2431f8a aa5931c 2431f8a aa5931c 2431f8a aa5931c 2431f8a aa5931c 2431f8a aa5931c 2431f8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import os
import gdown
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import gradio as gr
# Fetch the fine-tuned weights from Google Drive on first start-up.
# NOTE(review): only model.safetensors is downloaded here; config and tokenizer
# files are assumed to already be present in ./model — confirm.
os.makedirs("model", exist_ok=True)
MODEL_URL = "https://drive.google.com/uc?id=1Kg8KSGIgjBopeOKSbYbFWEgUlYOcqyXX"  # <- replace with your file ID
MODEL_PATH = "model/model.safetensors"
if not os.path.exists(MODEL_PATH):
    print("⬇ Downloading model weights...")
    # Download under a temporary name and rename only on success, so an
    # interrupted download never leaves a truncated file that the existence
    # check above would mistake for a complete one on the next start.
    _partial_path = MODEL_PATH + ".part"
    if gdown.download(MODEL_URL, _partial_path, quiet=False) is None:
        # gdown may signal failure by returning None instead of raising.
        raise RuntimeError(f"Failed to download model weights from {MODEL_URL}")
    os.replace(_partial_path, MODEL_PATH)
else:
    print("✅ Model file already exists")
print("🔧 Loading model & tokenizer...")
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is only a win on GPU; many CPU kernels have no (or very slow)
# Half-precision implementations, so use float32 when running on CPU.
model_dtype = torch.float16 if device == "cuda" else torch.float32
tokenizer = AutoTokenizer.from_pretrained("model")
model = AutoModelForCausalLM.from_pretrained("model", torch_dtype=model_dtype)
model.to(device)
# NOTE(review): this streamer is not passed to model.generate() in respond(),
# so it currently has no effect; kept so the module-level name stays available.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
def _build_prompt(message, history):
    """Flatten prior chat turns plus the new message into the model's prompt.

    Accepts both Gradio history formats: "tuples" (``[[user, bot], ...]``)
    and "messages" (``[{"role": ..., "content": ...}, ...]``).
    """
    parts = []
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages" format: one dict per message; ignore non-chat roles.
            role = turn.get("role")
            if role in ("user", "assistant"):
                parts.append(f"<|{role}|>{turn.get('content', '')}")
        else:
            user, bot = turn
            # A pending turn may not have a bot reply yet.
            bot = "" if bot is None else bot
            parts.append(f"<|user|>{user}<|assistant|>{bot}")
    parts.append(f"<|user|>{message}<|assistant|>")
    return "".join(parts)


def respond(message, history, max_tokens, temperature, top_p):
    """Generate the assistant's reply to ``message`` given the chat ``history``.

    Args:
        message: The latest user message.
        history: Prior turns as supplied by ``gr.ChatInterface`` (either the
            "tuples" or the "messages" format).
        max_tokens: Maximum number of new tokens to sample.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The model's reply text, stripped of surrounding whitespace.
    """
    prompt = _build_prompt(message, history)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(
        **inputs,
        max_new_tokens=int(max_tokens),  # slider values may arrive as floats
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens. Decoding the full sequence and
    # splitting on "<|assistant|>" returns the whole prompt if that marker is
    # removed by skip_special_tokens, and misfires if users type the marker.
    prompt_len = inputs["input_ids"].shape[-1]
    reply = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    # The model may run on and invent the next user turn; keep only its reply.
    # (If "<|user|>" is a special token it was already stripped and this is a no-op.)
    return reply.split("<|user|>")[0].strip()
# ==== Gradio UI ====
# Sampling controls rendered below the chat box; Gradio passes their values to
# respond() after (message, history), in this order: max_tokens, temperature, top_p.
_generation_controls = [
    gr.Slider(64, 1024, value=256, label="Max Tokens"),
    gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
]

chat = gr.ChatInterface(
    fn=respond,
    additional_inputs=_generation_controls,
    title="🦙 TinyLLaMA Chatbot",
    description="Fine-tuned TinyLLaMA using QLoRA.",
)


def _main():
    """Start the Gradio server hosting the chat interface."""
    chat.launch()


if __name__ == "__main__":
    _main()
|