# code_legalist / app.py — Hugging Face Space by rxhulshxrmx
# Commit 3e20c09 (verified): "Update app.py"
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
class AIAdvocateChatbot:
    """Conversational chatbot wrapping an open-source causal language model."""

    def __init__(self, model_name="meta-llama/Llama-2-7b-chat-hf"):
        """
        Load the tokenizer, model, and text-generation pipeline.

        Args:
            model_name: Hugging Face Hub repository id. Fixed from the
                original "Meta-llama/..." — Hub repo ids are case-sensitive
                and the official repo is lowercase "meta-llama/...".
        """
        # Prefer GPU when available; fp16 on GPU halves memory use.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print("Loading model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Llama tokenizers ship without a pad token; reuse EOS so that any
        # generation utility that needs padding does not raise.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device)

        print("Setting up text-generation pipeline...")
        # do_sample=True is required: without it, transformers greedy-decodes
        # and temperature/top_p are ignored (newer versions warn). max_length
        # is dropped here because it conflicts with the per-call
        # max_new_tokens used in generate_response().
        self.generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if self.device == "cuda" else -1,
            do_sample=True,
            temperature=0.7,  # response creativity
            top_p=0.9,        # nucleus sampling for conversational quality
        )

    def generate_response(self, message, history):
        """
        Generate a reply from the user message and prior chat turns.

        Args:
            message: Latest user message.
            history: Prior turns, either as (user, assistant) pairs
                (Gradio "tuples" format, as the original assumed) or as
                {"role": ..., "content": ...} dicts (Gradio "messages"
                format) — both are accepted.

        Returns:
            The assistant's reply text, or an error string on failure
            (so the Gradio UI shows the problem instead of crashing).
        """
        try:
            # Flatten history into "Human: .../Assistant: ..." transcript lines.
            lines = []
            for turn in history:
                if isinstance(turn, dict):
                    # Gradio "messages" format: one dict per utterance.
                    role = "Human" if turn.get("role") == "user" else "Assistant"
                    lines.append(f"{role}: {turn.get('content', '')}")
                else:
                    # Gradio "tuples" format: (user, assistant) per turn.
                    user_msg, bot_msg = turn
                    lines.append(f"Human: {user_msg}\nAssistant: {bot_msg}")
            context = "\n".join(lines)

            full_prompt = f"{context}\nHuman: {message}\nAssistant:"

            print("Generating response...")
            response = self.generator(
                full_prompt,
                max_new_tokens=200,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
            )[0]["generated_text"]

            # The pipeline echoes the prompt; keep only the text after the
            # final "Assistant:" marker — the newly generated turn.
            return response.split("Assistant:")[-1].strip()
        except Exception as e:
            # Surface the failure in the chat UI rather than raising.
            return f"Error generating response: {str(e)}"
def create_chatbot_interface():
    """Build and return the Gradio chat UI wired to a chatbot instance."""
    # A single chatbot instance serves every chat session.
    bot = AIAdvocateChatbot()

    description_text = (
        "An advanced conversational AI chatbot. "
        "Designed for engaging, human-like interactions and future integration "
        "into Retrieval-Augmented Generation systems."
    )

    return gr.ChatInterface(
        bot.generate_response,
        title="🤖 AI Advocate Chatbot",
        description=description_text,
    )
def main():
    """Entry point: build the Gradio app and serve it publicly."""
    # share=True exposes a temporary public URL via Gradio's tunnel.
    create_chatbot_interface().launch(share=True)


if __name__ == "__main__":
    main()