FadQ committed on
Commit
b873662
·
verified ·
1 Parent(s): bc8151e
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -1,22 +1,22 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
- import os
4
 
5
  # Gunakan model yang sudah diunggah ke Hugging Face
6
  model_path = "FadQ/gemma-2b-diary-consultaton-chatbot"
7
 
8
- # Load the base model and tokenizer with trust_remote_code=True
9
- model = AutoModelForCausalLM.from_pretrained(base_model, trust_remote_code=True)
10
- tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
11
 
12
- # Ensure pipeline uses the correct tokenizer and check device
13
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0) # Ubah ke device=-1 jika pakai CPU
14
 
15
  def predict(input_text):
16
  result = pipe(input_text, max_length=150, num_return_sequences=1)
17
  return result[0]["generated_text"]
18
 
19
- # Create the Gradio interface
20
  demo = gr.Interface(fn=predict, inputs=gr.Textbox(label="Input Text"), outputs="text")
21
 
22
- demo.launch()
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
+ import torch
4
 
5
  # Gunakan model yang sudah diunggah ke Hugging Face
6
  model_path = "FadQ/gemma-2b-diary-consultaton-chatbot"
7
 
8
+ # Load model dan tokenizer dengan `trust_remote_code=True`
9
+ model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
10
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
11
 
12
+ # Buat pipeline dengan tokenizer yang sesuai
13
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
14
 
15
  def predict(input_text):
16
  result = pipe(input_text, max_length=150, num_return_sequences=1)
17
  return result[0]["generated_text"]
18
 
19
+ # Buat antarmuka Gradio
20
  demo = gr.Interface(fn=predict, inputs=gr.Textbox(label="Input Text"), outputs="text")
21
 
22
+ demo.launch()