FORMATPOTTER commited on
Commit
fa4631c
·
verified ·
1 Parent(s): ecae831

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -21
app.py CHANGED
@@ -1,30 +1,23 @@
1
  import gradio as gr
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
3
 
4
- # Path to the uploaded model folder
5
- model_path = "model"
6
-
7
- # Load model and tokenizer
8
  tokenizer = AutoTokenizer.from_pretrained(model_path)
9
  model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
 
10
 
11
- # Create text2text-generation pipeline
12
- generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
13
-
14
- # Define Gradio interface function
15
- def ask_question(question):
16
- result = generator(question, max_new_tokens=100)
17
- return result[0]['generated_text']
18
 
19
- # Gradio interface
20
- iface = gr.Interface(
21
- fn=ask_question,
22
- inputs=gr.Textbox(lines=2, placeholder="Ask anything..."),
23
- outputs=gr.Textbox(),
24
- title="FPV2 Chatbot",
25
- description="Ask questions and get answers from the FPV2 model."
26
- )
27
 
28
- # Launch the interface
29
- iface.launch(server_name="0.0.0.0", server_port=7860)
30
 
 
1
  import gradio as gr
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
3
 
4
+ # Load your model
5
+ model_path = "./model" # your uploaded folder
 
 
6
  tokenizer = AutoTokenizer.from_pretrained(model_path)
7
  model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
8
+ chat_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
9
 
10
+ # Function to handle chat
11
+ def chat_with_bot(history, user_message):
12
+ response = chat_pipeline(user_message)[0]['generated_text']
13
+ history.append((user_message, response))
14
+ return history, ""
 
 
15
 
16
+ # Gradio UI
17
+ with gr.Blocks() as demo:
18
+ chatbot = gr.Chatbot()
19
+ msg = gr.Textbox(placeholder="Type your message here...")
20
+ msg.submit(chat_with_bot, [chatbot, msg], [chatbot, msg])
 
 
 
21
 
22
+ demo.launch(server_name="0.0.0.0", server_port=7860)
 
23