rohangbs commited on
Commit
970a3e5
·
verified ·
1 Parent(s): d35fc36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -8
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import streamlit as st
 
 
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
 
4
  # Load the fine-tuned model and tokenizer
@@ -6,6 +8,39 @@ model_name = "rohangbs/fine-tuned-gpt2"
6
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
7
  model = GPT2LMHeadModel.from_pretrained(model_name)
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Streamlit UI
10
  st.title("Chatbot For Company Details")
11
  st.write("A GPT-2 model fine-tuned for Company dataset.")
@@ -16,14 +51,10 @@ prompt = st.text_area("Ask your question:", height=150)
16
  if st.button("Send"):
17
  if prompt.strip():
18
  with st.spinner("Generating..."):
19
- # Generate text
20
- inputs = tokenizer.encode(prompt, return_tensors="pt")
21
- outputs = model.generate(inputs, max_length=200, num_return_sequences=1)
22
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
23
-
24
- # Display the response
25
- st.subheader("Generated Response:")
26
- st.write(response)
27
  else:
28
  st.warning("Please enter a prompt.")
29
 
 
1
  import streamlit as st
2
+ import torch
3
+ import re
4
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
5
 
6
  # Load the fine-tuned model and tokenizer
 
8
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
9
  model = GPT2LMHeadModel.from_pretrained(model_name)
10
 
11
+ # Ensure the model is on the correct device
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ model = model.to(device)
14
+
15
+ # Function to generate a response
16
+ def chat_with_model(input_prompt, max_length=200):
17
+ model.eval()
18
+
19
+ # Format the input prompt with special tokens
20
+ prompt = f"<|startoftext|>[WP] {input_prompt}\n[RESPONSE]"
21
+
22
+ # Tokenize and encode the prompt, and send to the device
23
+ generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
24
+
25
+ # Generate a response
26
+ sample_outputs = model.generate(
27
+ generated,
28
+ do_sample=True,
29
+ top_k=50,
30
+ max_length=max_length,
31
+ top_p=0.95,
32
+ num_return_sequences=1,
33
+ pad_token_id=tokenizer.eos_token_id
34
+ )
35
+
36
+ # Decode the response and clean it up
37
+ response_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
38
+ wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", response_text)[1:]
39
+ clean_responses = [response.strip() for response in wp_responses if response.strip()]
40
+
41
+ # Return the first valid response
42
+ return clean_responses[0] if clean_responses else "I couldn't generate a response."
43
+
44
  # Streamlit UI
45
  st.title("Chatbot For Company Details")
46
  st.write("A GPT-2 model fine-tuned for Company dataset.")
 
51
  if st.button("Send"):
52
  if prompt.strip():
53
  with st.spinner("Generating..."):
54
+ # Generate and display the response
55
+ response = chat_with_model(prompt)
56
+ st.subheader("Generated Response:")
57
+ st.write(response)
 
 
 
 
58
  else:
59
  st.warning("Please enter a prompt.")
60