DSDUDEd committed on
Commit
bea9bd0
·
verified ·
1 Parent(s): f96b59d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -32
app.py CHANGED
@@ -1,42 +1,54 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
 
5
- # Load your fine-tuned model
6
- MODEL_REPO = "DSDUDEd/firebase" # replace with your Hugging Face model repo
7
  tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
8
- model = AutoModelForCausalLM.from_pretrained(MODEL_REPO)
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model.to(device)
11
-
12
- def chat_with_ai(user_input, history=[]):
13
- """
14
- Accepts user input and conversation history.
15
- Generates a response and appends it to the chat history.
16
- """
17
- # Combine all previous messages into a single prompt
18
- prompt = ""
19
- for user_msg, ai_msg in history:
20
- prompt += f"User: {user_msg}\nAI: {ai_msg}\n"
21
- prompt += f"User: {user_input}\nAI:"
22
-
23
- # Tokenize and generate
24
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
25
- output_ids = model.generate(input_ids, max_new_tokens=150, pad_token_id=tokenizer.eos_token_id)
26
- response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
27
-
28
- # Update history
29
- history.append((user_input, response))
30
- return history, history
31
-
32
- # Gradio Chat Interface
 
 
 
 
 
 
33
  with gr.Blocks() as demo:
34
- gr.Markdown("<h1 style='text-align: center;'>πŸ€– Jenna GPT-2 Chat</h1>")
 
35
  chatbot = gr.Chatbot()
36
- msg = gr.Textbox(label="Your message", placeholder="Type your message here and press Enter")
37
  clear = gr.Button("Clear Chat")
38
 
39
- msg.submit(chat_with_ai, [msg, chatbot], [chatbot, chatbot])
 
 
 
 
40
  clear.click(lambda: [], None, chatbot)
41
 
42
- demo.launch()
 
 
1
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from datetime import datetime

# Load your custom GPT-2 8-bit model
MODEL_REPO = "DSDUDEd/firebase"  # replace with your model repo
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)

# NOTE(review): `load_in_8bit=True` is deprecated in recent transformers
# releases (use `quantization_config=BitsAndBytesConfig(load_in_8bit=True)`)
# and requires the bitsandbytes package plus a CUDA device; on a CPU-only
# host the original call raises at import time and the whole app dies.
# Fall back to a plain full-precision load so the Space stays usable.
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO, device_map="auto", load_in_8bit=True
    )
except Exception as err:
    print(f"8-bit load failed ({err!r}); falling back to full precision.")
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO)

# Conversation log shared by the UI handlers: list of (user_entry, ai_entry)
# tuples, the shape gr.Chatbot displays.
# NOTE(review): module-level state is shared across ALL concurrent users of
# the app — fine for a single-user demo, wrong for a shared Space.
chat_history = []
12
+
13
def format_message(sender, message):
    """Return *message* prefixed with the sender name and an HH:MM timestamp.

    Produces display strings like ``You [14:03]: hello`` for the chat log.
    """
    now = datetime.now().strftime("%H:%M")
    return "{} [{}]: {}".format(sender, now, message)
17
+
18
def chat_with_model(user_message):
    """Generate a reply to *user_message* and append the exchange to history.

    Returns the updated ``chat_history`` as a list of
    (user_entry, ai_entry) tuples — the format gr.Chatbot renders.
    """
    global chat_history

    # Timestamped entry for the user's side of the exchange.
    user_entry = format_message("You", user_message)

    # Tokenize the prompt and keep the tensors on the model's device so this
    # works for both CPU and GPU/8-bit (`device_map="auto"`) loads.
    inputs = tokenizer(user_message, return_tensors="pt").to(model.device)

    # pad_token_id is set explicitly: GPT-2 has no pad token, and omitting it
    # triggers a transformers warning on every generate() call.
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        pad_token_id=tokenizer.eos_token_id,
    )

    # BUG FIX: decode only the newly generated tokens. generate() returns
    # prompt + continuation, so decoding outputs[0] whole echoed the user's
    # message back at the start of every AI reply.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    ai_message_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    ai_entry = format_message("AI", ai_message_text)

    # Record the exchange for the Chatbot widget.
    chat_history.append((user_entry, ai_entry))

    # NOTE(review): only the latest message is fed to the model — earlier
    # turns in chat_history are not part of the prompt, so the model has no
    # conversational memory. Consider rebuilding the prompt from history.
    return chat_history
37
+
38
# Gradio Interface
with gr.Blocks() as demo:
    # Header emoji restored from mojibake ("πŸ€–" was UTF-8 read as Latin-1).
    gr.Markdown("## 🤖 Custom GPT-2 AI Chat")

    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type your message here...", show_label=False)
    clear = gr.Button("Clear Chat")

    def respond(message, history):
        """Submit handler: run the model, refresh the log, empty the textbox."""
        output = chat_with_model(message)
        return output, ""

    def clear_chat():
        """Reset both the widget and the module-level history.

        BUG FIX: the original `clear.click(lambda: [], None, chatbot)` only
        emptied the Chatbot widget; the global chat_history survived, so the
        whole old conversation reappeared on the very next message.
        """
        global chat_history
        chat_history = []
        return []

    msg.submit(respond, [msg, chatbot], [chatbot, msg])
    clear.click(clear_chat, None, chatbot)

if __name__ == "__main__":
    demo.launch()