Mehak-Mazhar committed on
Commit
ac982af
·
verified ·
1 Parent(s): a87f9a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -61
app.py CHANGED
@@ -2,71 +2,52 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
- # Load a lightweight, conversational model
6
- model_name = "microsoft/DialoGPT-small"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
- # Store chat history
11
- chat_history_ids = None
12
-
13
- # Inference function
14
- def chat_with_bot(user_input, history=[]):
15
- global chat_history_ids
16
- new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
17
-
18
- if chat_history_ids is not None:
19
- bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
20
- else:
21
- bot_input_ids = new_input_ids
22
-
23
- chat_history_ids = model.generate(
24
- bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id,
25
- no_repeat_ngram_size=3, do_sample=True, top_k=100, top_p=0.7, temperature=0.8
26
- )
27
-
28
- response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
29
- history.append((user_input, response))
30
- return history, history
31
-
32
- # Gradio Interface
33
- def reset_history():
34
- global chat_history_ids
35
- chat_history_ids = None
36
- return []
37
-
38
- theme_css = """
39
- body {
40
- background-color: #FFF3E0;
41
- font-family: Arial, sans-serif;
42
- }
43
-
44
- h1 {
45
- color: darkorange;
46
- text-align: center;
47
- font-weight: bold;
48
- }
49
-
50
- .gradio-container {
51
- max-width: 900px;
52
- margin: auto;
53
- padding: 20px;
54
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.2);
55
- border-radius: 12px;
56
- background-color: white;
57
- }
58
- """
59
-
60
- with gr.Blocks(css=theme_css) as demo:
61
- gr.Markdown("""<h1>🤖 AI Chatbot Assistant</h1>""")
62
-
63
- chatbot = gr.Chatbot()
64
- msg = gr.Textbox(placeholder="Type your message here...", label="Your Message")
65
- clear = gr.Button("🧹 Clear Chat")
66
 
67
  state = gr.State([])
68
 
69
- msg.submit(chat_with_bot, [msg, state], [chatbot, state])
70
- clear.click(fn=reset_history, outputs=[chatbot, state])
71
 
72
- demo.launch()
 
 
 
2
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Conversational GPT-2 checkpoint — small (~500MB) so it loads on CPU Spaces.
model_name = "mrm8488/GPT-2-finetuned-on-conversational-data"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Prefer a GPU when one is available, otherwise stay on CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model.to(device)
13
# Response generation logic
def generate_response(message, chat_history):
    """Generate the next assistant reply and append the turn to history.

    Parameters:
        message: The user's new message (str).
        chat_history: Gradio "messages"-style history — a list of
            {"role": ..., "content": ...} dicts, or None on first call.

    Returns:
        A (chatbot_value, state_value) pair — one value per wired Gradio
        output component ([chatbot, state]).
    """
    try:
        chat_history = chat_history or []

        # Rebuild the conversation as a plain-text prompt. History entries
        # are stored as {"role": ..., "content": ...} dicts, so derive the
        # "User:"/"AI:" label from the role instead of treating every turn
        # as a user line (the old code also read a 'response' key that is
        # never written).
        lines = []
        for turn in chat_history:
            label = "User" if turn.get("role") == "user" else "AI"
            lines.append(f"{label}: {turn.get('content', '')}")
        lines.append(f"User: {message}")
        lines.append("AI:")
        full_prompt = "\n".join(lines)

        inputs = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
        outputs = model.generate(inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the newest reply: text after the last "AI:" marker,
        # truncated if the model hallucinated a follow-up "User:" turn.
        response = decoded.split("AI:")[-1].split("User:")[0].strip()

        # Add both sides of the exchange to the history.
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": response})

        # msg.submit wires two outputs ([chatbot, state]); Gradio expects
        # one return value per output, so return the history twice.
        return chat_history, chat_history
    except Exception as e:
        errored = chat_history + [{"role": "assistant", "content": f"⚠️ Error: {str(e)}"}]
        return errored, errored
36
+
37
# Build the Gradio UI and wire up its events.
with gr.Blocks(css="body { background-color: #FFF9C4; }") as demo:
    gr.Markdown("<h1 style='color:brown; font-weight:bold;'>🧠 AI Chatbot</h1>")

    # Widgets: messages-style chat window, input box, clear button, footer.
    chatbot = gr.Chatbot(label="Talk to AI", type="messages")
    msg = gr.Textbox(label="Ask something...", placeholder="Type your message here...")
    clear = gr.Button("Clear Chat")
    footer = gr.Markdown("<p style='text-align: right; color: brown;'>Designed by Mehak Mazhar</p>")

    # Per-session conversation history.
    state = gr.State([])

    def _reset_chat():
        # Empty both the visible chat window and the stored history.
        return [], []

    # Enter in the textbox sends the message; the button wipes the chat.
    msg.submit(generate_response, [msg, state], [chatbot, state])
    clear.click(_reset_chat, None, [chatbot, state])

# Run
if __name__ == "__main__":
    demo.launch()