File size: 1,654 Bytes
e19dcac
3eefc67
3a2f449
b873662
3a2f449
 
 
 
 
 
95a7c9d
3a2f449
 
dc2bc22
 
 
 
 
80ac91b
dc2bc22
 
 
 
80ac91b
e19dcac
dc2bc22
 
 
 
80ac91b
 
dc2bc22
80ac91b
95a7c9d
3a2f449
 
7b91d9a
3a2f449
b873662
e19dcac
27564ff
3a2f449
dc2bc22
80ac91b
3a2f449
27564ff
3a2f449
 
 
 
 
 
e19dcac
3a2f449
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
import torch
import os

# Hugging Face access token taken from the environment (None when HF_TOKEN is
# unset). Required when a referenced Hub repository is private.
hf_token = os.environ.get("HF_TOKEN")

# Hub IDs: the base checkpoint, and the PEFT adapter applied on top of it.
base_model = "google/gemma-2b-it"
adapter_model = "FadQ/gemma-2b-diary-consultaton-chatbot"

# Best-effort upgrade of the libraries this script relies on.
# NOTE: transformers/peft are already imported at the top of this file, so an
# upgrade performed here only takes effect the next time the process starts;
# pinning versions in requirements.txt is the more reliable option.
# pip is invoked through the running interpreter (sys.executable) so packages
# are installed into this environment, not into whichever `pip` happens to be
# first on PATH (which may also be missing entirely).
import subprocess
import sys

subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "--upgrade",
        "peft", "transformers", "accelerate",
    ],
    check=False,  # a failed upgrade stays non-fatal, as before
)

# Load the base model with all weights materialised (no meta tensors) and
# place it on a single device.
#
# Why no device_map here: device_map="auto" hands placement to accelerate,
# which may split/offload layers and records `hf_device_map` on the model.
# Moving such a model again with .to() is unreliable, and transformers'
# pipeline(..., device=...) refuses a model that has `hf_device_map`.
# Loading plainly and then calling .to(device) yields one predictable device.
_device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    # float16 halves GPU memory; many CPU kernels lack float16 support, so
    # fall back to float32 when no GPU is available.
    torch_dtype=torch.float16 if _device == "cuda" else torch.float32,
    low_cpu_mem_usage=True,
    token=hf_token,  # needed if the base repository is private/gated
)
model = model.to(_device)

# Wrap the already-loaded base model with the fine-tuned PEFT adapter.
# The token is forwarded so a private adapter repository can be downloaded;
# with token=None the Hub client falls back to any locally stored credentials.
model = PeftModel.from_pretrained(
    model,
    adapter_model,
    token=hf_token,
)

# Tokenizer of the base checkpoint (the adapter does not change the vocabulary
# it is loaded from here). The token is passed because the base repository may
# require authentication to download.
tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)

# Text-generation pipeline over the adapted model (predict() below calls
# model.generate directly and does not use it).
#
# A model placed by accelerate (loaded with a device_map) carries
# `hf_device_map`, and transformers' pipeline raises if an explicit `device`
# is also given. Otherwise pin the pipeline to the device the model already
# lives on instead of assuming GPU 0, which crashed on CPU-only hosts.
_pipe_device = None if getattr(model, "hf_device_map", None) else model.device
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=_pipe_device,
)

def predict(input_text, max_new_tokens=150):
    """Generate a continuation of ``input_text`` with the adapted model.

    Args:
        input_text: Prompt text (supplied by the Gradio input textbox).
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt.

    Returns:
        The decoded output sequence -- the prompt followed by the generated
        text -- with special tokens removed.
    """
    # Put the inputs wherever the model's weights live instead of assuming
    # CUDA, so the app also works on CPU-only hosts.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # max_new_tokens bounds only the generated part. The previous
        # max_length=150 also counted prompt tokens, so long prompts left
        # little or no room for an answer.
        output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Web UI: a single text box in, a single text box out, wired to predict().
demo = gr.Interface(
    predict,
    gr.Textbox(label="Input Text"),
    gr.Textbox(label="Generated Response"),
)

# Start the Gradio server only when this file is executed directly, not when
# it is imported as a module.
if __name__ == '__main__':
    demo.launch()