DSDUDEd committed on
Commit
60cffd8
Β·
verified Β·
1 Parent(s): bea9bd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -37
app.py CHANGED
@@ -1,54 +1,55 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
- from datetime import datetime
4
 
5
- # Load your custom GPT-2 8-bit model
6
- MODEL_REPO = "DSDUDEd/firebase" # replace with your model repo
7
  tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
8
- model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, device_map="auto", load_in_8bit=True)
9
 
10
- # Keep chat history
11
- chat_history = []
 
12
 
13
- def format_message(sender, message):
14
- """Format message with timestamp and sender"""
15
- timestamp = datetime.now().strftime("%H:%M")
16
- return f"{sender} [{timestamp}]: {message}"
17
 
18
- def chat_with_model(user_message):
 
19
  global chat_history
20
-
21
- # Add user message
22
- user_entry = format_message("You", user_message)
23
 
24
- # Encode user input
25
- inputs = tokenizer(user_message, return_tensors="pt")
 
26
 
27
- # Generate AI response
28
- outputs = model.generate(**inputs, max_new_tokens=150)
29
- ai_message_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
30
- ai_entry = format_message("AI", ai_message_text)
 
 
 
 
 
31
 
32
- # Update chat history
33
- chat_history.append((user_entry, ai_entry))
34
 
35
- # Return chat history as list of tuples for Gradio Chatbot
36
- return chat_history
 
 
 
 
37
 
38
- # Gradio Interface
39
  with gr.Blocks() as demo:
40
  gr.Markdown("## πŸ€– Custom GPT-2 AI Chat")
41
-
42
- chatbot = gr.Chatbot()
43
- msg = gr.Textbox(placeholder="Type your message here...", show_label=False)
44
- clear = gr.Button("Clear Chat")
45
-
46
- def respond(message, history):
47
- output = chat_with_model(message)
48
- return output, ""
49
 
50
- msg.submit(respond, [msg, chatbot], [chatbot, msg])
51
- clear.click(lambda: [], None, chatbot)
52
 
53
- if __name__ == "__main__":
54
- demo.launch()
 
1
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# 1️⃣ Load the model and tokenizer from the Hugging Face Hub.
MODEL_REPO = "DSDUDEd/firebase"  # your HF model repo
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
model = AutoModelForCausalLM.from_pretrained(MODEL_REPO)

# Set device (CPU or GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# This app only ever generates (no training), so put the model in eval mode:
# otherwise dropout and similar train-mode layers stay active during generate().
model.eval()

# 2️⃣ Chat history — module-level so it persists across calls.
# One string per turn, formatted "You: ..." or "AI: ...".
chat_history = []
 
 
16
 
17
# 3️⃣ Function to generate AI response
def chat_with_ai(user_input):
    """Generate the AI's reply to *user_input* and return the full transcript.

    Appends the user's turn and the model's reply to the module-level
    ``chat_history`` and returns the whole conversation as one newline-joined
    string (one "You: ..." / "AI: ..." line per turn) for display.

    Parameters:
        user_input: raw text the user typed.

    Returns:
        str: the entire chat history joined with newlines.
    """
    global chat_history

    # Build the prompt from history plus the new user turn WITHOUT mutating
    # chat_history yet: if generate() raises, the global history must stay
    # consistent instead of accumulating an unanswered "You:" line.
    pending = chat_history + [f"You: {user_input}"]
    input_text = "\n".join(pending) + "\nAI:"
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    # Inference only — no_grad() skips building the autograd graph, which
    # would otherwise waste memory on every request.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # GPT-2 has no dedicated pad token; reuse EOS to silence the
            # warning and keep generate() well-defined.
            pad_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded text contains the entire prompt; keep only what follows the
    # final "AI:" marker — that is the newly generated reply.
    ai_response = response.split("AI:")[-1].strip()

    # Commit both turns only now that generation succeeded.
    chat_history = pending
    chat_history.append(f"AI: {ai_response}")

    # Display the chat nicely
    return "\n".join(chat_history)
44
 
45
# 4️⃣ Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## πŸ€– Custom GPT-2 AI Chat")

    # One input box for the user's message, a read-only transcript view,
    # and a button that triggers generation.
    message_box = gr.Textbox(label="Your Message", placeholder="Type here...", lines=2)
    transcript_box = gr.Textbox(label="Chat Output", interactive=False, lines=15)
    send_btn = gr.Button("Send")

    # Clicking "Send" runs the message through the model and shows the
    # updated transcript in the output box.
    send_btn.click(fn=chat_with_ai, inputs=message_box, outputs=transcript_box)

# 5️⃣ Launch the Space
demo.launch()