# testSpace / app.py — DialoGPT chatbot with a sentiment-analysis branch for movie talk.
# Provenance (from the Hugging Face Space file viewer): author musaashaikh,
# commit 87e720f ("Update app.py"), verified.
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
# Load DialoGPT model and tokenizer.
# NOTE: both from_pretrained calls download the model weights from the Hugging Face
# Hub on first run (DialoGPT-large is several GB) and cache them locally.
model_name = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Load a sentiment analysis pipeline (binary POSITIVE/NEGATIVE classifier,
# fine-tuned on SST-2); used below for movie/film-related inputs.
sentiment_analyzer = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
# Running DialoGPT conversation history as a (1, seq_len) tensor of token ids;
# None until the first exchanged turn.
chat_history_ids = None

# Interactive chat loop: reads from stdin, replies on stdout.
# Movie/film mentions are answered by the sentiment classifier; everything
# else is handled generatively by DialoGPT with the accumulated history.
print("Chatbot: Hi! I'm DialoGPT. Let's chat about anything, including movies! (type 'quit' to exit)")
while True:
    try:
        # Get user input
        user_input = input("You: ")
    except EOFError:
        # stdin is exhausted (e.g. piped input ran out) — exit gracefully.
        # Message fixed: the original said "let's continue!" while actually quitting.
        print("\nChatbot: I can't read any more input, so I'll say goodbye!")
        user_input = "quit"

    # Exit the loop if user types 'quit'
    if user_input.lower() == "quit":
        print("Chatbot: Goodbye! Have a great day!")
        break

    # Movie/film-related inputs get a sentiment-based canned reply instead of
    # a generated one; the turn is NOT added to the DialoGPT history.
    if "movie" in user_input.lower() or "film" in user_input.lower():
        sentiment = sentiment_analyzer(user_input)[0]
        sentiment_label = sentiment["label"]
        sentiment_score = sentiment["score"]
        if sentiment_label == "POSITIVE":
            response = f"Sounds like you really enjoyed the movie! I'm glad to hear that. 😊 (Confidence: {sentiment_score:.2f})"
        else:
            response = f"I'm sorry to hear you didn't enjoy the movie. 😞 (Confidence: {sentiment_score:.2f})"
        print(f"Chatbot: {response}")
        continue

    # Append this turn to the history; DialoGPT delimits turns with the EOS token.
    input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
    chat_history_ids = (
        torch.cat([chat_history_ids, input_ids], dim=-1) if chat_history_ids is not None else input_ids
    )

    # Generate a response.
    # BUG FIX: without do_sample=True, generate() uses greedy decoding and
    # silently ignores top_k / top_p / temperature.
    response_ids = model.generate(
        chat_history_ids,
        max_length=1000,  # cap on TOTAL length (history + reply); very long chats will hit it
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature=0.7,
    )

    # Decode only the newly generated tokens (everything after the prompt/history).
    response = tokenizer.decode(response_ids[:, chat_history_ids.shape[-1]:][0], skip_special_tokens=True)
    print(f"Chatbot: {response}")