Spaces:
Sleeping
Sleeping
| # Import necessary libraries | |
| import streamlit as st | |
| import mysql.connector | |
| import bcrypt | |
| import datetime | |
| import re | |
| import pytz | |
| import time | |
| # Import transformers library for GPT-2 model | |
| from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
| import torch | |
| # Configure Streamlit page settings | |
| icon='chatbot.png' | |
| st.set_page_config(page_title='GUVI - GPT', page_icon=icon, layout = 'wide') | |
| # Connect to TiDB Cloud database | |
| mydb = mysql.connector.connect( | |
| host='gateway01.ap-southeast-1.prod.aws.tidbcloud.com', | |
| port='4000', | |
| user='3yA2s2Jo7WfjUKk.root', | |
| password='2TqP3zENKkkdjgr0', | |
| database='test' | |
| ) | |
| mycursor = mydb.cursor(buffered=True) | |
| # Create 'GUVI_DB' database and use | |
| mycursor.execute("CREATE DATABASE IF NOT EXISTS GUVI_DB") | |
| mycursor.execute('USE GUVI_DB') | |
| # Create 'users' table if it does not exist | |
| mycursor.execute('''CREATE TABLE IF NOT EXISTS users ( | |
| id INT AUTO_INCREMENT PRIMARY KEY, | |
| username VARCHAR(50) UNIQUE NOT NULL, | |
| password VARCHAR(255) NOT NULL, | |
| email VARCHAR(255) UNIQUE NOT NULL, | |
| registered_date TIMESTAMP, | |
| last_login TIMESTAMP | |
| );''') | |
| # Check if username exists in the database | |
| def username_exists(username): | |
| mycursor.execute("SELECT * FROM users WHERE username = %s", (username,)) | |
| return mycursor.fetchone() is not None | |
| # Check if email exists in the database | |
| def email_exists(email): | |
| mycursor.execute("SELECT * FROM users WHERE email = %s", (email,)) | |
| return mycursor.fetchone() is not None | |
| # Validate email format using regular expressions | |
| def is_valid_email(email): | |
| pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' | |
| return re.match(pattern, email) is not None | |
| # Create a new user in the database | |
| def create_user(username, password, email): | |
| if username_exists(username): | |
| return 'username_exists' | |
| if email_exists(email): | |
| return 'email_exists' | |
| hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()) | |
| registered_date = datetime.datetime.now(pytz.timezone('Asia/Kolkata')) | |
| # Insert user data into 'users' table | |
| mycursor.execute( | |
| "INSERT INTO users (username, password, email, registered_date) VALUES (%s, %s, %s, %s)", | |
| (username, hashed_password, email, registered_date) | |
| ) | |
| mydb.commit() | |
| return 'success' | |
| # Verify user credentials | |
| def verify_user(username, password): | |
| mycursor.execute("SELECT password FROM users WHERE username = %s", (username,)) | |
| record = mycursor.fetchone() | |
| if record and bcrypt.checkpw(password.encode('utf-8'), record[0].encode('utf-8')): | |
| # Update last login timestamp | |
| mycursor.execute("UPDATE users SET last_login = %s WHERE username = %s", (datetime.datetime.now(pytz.timezone('Asia/Kolkata')), username)) | |
| mydb.commit() | |
| return True | |
| return False | |
| # Reset user password | |
| def reset_password(username, new_password): | |
| hashed_password = bcrypt.hashpw(new_password.encode('utf-8'), bcrypt.gensalt()) | |
| # Update password in 'users' table | |
| mycursor.execute( | |
| "UPDATE users SET password = %s WHERE username = %s", | |
| (hashed_password, username) | |
| ) | |
| mydb.commit() | |
| # Session state management | |
| if 'sign_up_successful' not in st.session_state: | |
| st.session_state.sign_up_successful = False | |
| if 'login_successful' not in st.session_state: | |
| st.session_state.login_successful = False | |
| if 'reset_password' not in st.session_state: | |
| st.session_state.reset_password = False | |
| if 'username' not in st.session_state: | |
| st.session_state.username = '' | |
| if 'current_page' not in st.session_state: | |
| st.session_state.current_page = 'login' | |
| # Login form | |
| def login(): | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| with st.form(key='login', clear_on_submit=True): | |
| st.subheader(':blue[**Login**]') | |
| st.write("Enter your username and password below.") | |
| # Input fields for username and password | |
| username = st.text_input(label='Username', placeholder='Enter Your Username') | |
| password = st.text_input(label='Password', placeholder='Enter Your Password', type='password') | |
| if st.form_submit_button('Login'): | |
| if not username or not password: | |
| st.error("Please fill out all fields.") | |
| elif verify_user(username, password): | |
| st.session_state.login_successful = True | |
| st.session_state.username = username | |
| st.session_state.current_page = 'home' | |
| st.rerun() | |
| else: | |
| st.error("Incorrect username or password. If you don't have an account, please sign up.") | |
| with col2: | |
| st.image('login.png', width = 320) | |
| # Display sign-up and reset password button | |
| if not st.session_state.login_successful: | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.write(":red[New user?]") | |
| if st.button('Sign Up'): | |
| st.session_state.current_page = 'sign_up' | |
| st.rerun() | |
| with col2: | |
| st.write(":red[Forgot Password?]") | |
| if st.button('Reset Password'): | |
| st.session_state.current_page = 'reset_password' | |
| st.rerun() | |
| # Sign-up form | |
| def signup(): | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| with st.form(key='signup', clear_on_submit=True): | |
| st.subheader(':blue[**Sign Up**]') | |
| st.write("Enter the required fields to create a new account.") | |
| # Input fields for email, username, and password | |
| email = st.text_input(label='Email', placeholder='Enter Your Email') | |
| username = st.text_input(label='Username', placeholder='Enter Your Username') | |
| password = st.text_input(label='Password', placeholder='Enter Your Password', type='password') | |
| re_password = st.text_input(label='Confirm Password', placeholder='Confirm Your Password', type='password') | |
| if st.form_submit_button('Sign Up'): | |
| if not email or not username or not password or not re_password: | |
| st.error("Please fill out all fields.") | |
| elif not is_valid_email(email): | |
| st.error("Please enter a valid email address.") | |
| elif len(password) <= 3: | |
| st.error("Password too short") | |
| elif password != re_password: | |
| st.error("Passwords do not match. Please re-enter.") | |
| else: | |
| result = create_user(username, password, email) | |
| if result == 'username_exists': | |
| st.error("Username already registered. Please use a different username.") | |
| elif result == 'email_exists': | |
| st.error("Email already registered. Please use a different email.") | |
| elif result == 'success': | |
| st.success(f"Username {username} created successfully! Please login.") | |
| st.session_state.sign_up_successful = True | |
| else: | |
| st.error("Failed to create user. Please try again later.") | |
| if st.session_state.sign_up_successful: | |
| if st.button('Go to Login'): | |
| st.session_state.current_page = 'login' | |
| st.rerun() | |
| # Reset password form | |
| def reset_password_page(): | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| with st.form(key='reset_password', clear_on_submit=True): | |
| st.subheader(':blue[Reset Password]') | |
| st.write("Enter your username and new password below.") | |
| # Input fields for username and new password | |
| username = st.text_input(label='Username', value='') | |
| new_password = st.text_input(label='New Password', type='password') | |
| re_password = st.text_input(label='Confirm New Password', type='password') | |
| if st.form_submit_button('Reset Password'): | |
| if not username: | |
| st.error("Please enter your username.") | |
| elif not username_exists(username): | |
| st.error("Username not found. Please enter a valid username.") | |
| elif not new_password or not re_password: | |
| st.error("Please fill out all fields.") | |
| elif len(new_password) <= 3: | |
| st.error("Password too short") | |
| elif new_password != re_password: | |
| st.error("Passwords do not match. Please re-enter.") | |
| else: | |
| reset_password(username, new_password) | |
| st.success("Password reset successfully. Please login with your new password.") | |
| st.session_state.current_page = 'login' | |
| # Button to return to login page | |
| st.write('Return to Login page') | |
| if st.button('Login'): | |
| st.session_state.current_page = 'login' | |
| st.rerun() | |
| # Load the fine-tuned model and tokenizer | |
| model_name_or_path = "./fine_tuned_model" | |
| model = GPT2LMHeadModel.from_pretrained(model_name_or_path) | |
| tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path) | |
| # Set the pad_token to eos_token if it's not already set | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| # Move the model to GPU if available | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model.to(device) | |
| # Define the text generation function | |
| def generate_text(model, tokenizer, seed_text, max_length = 100, temperature = 1.0, num_return_sequences = 1): | |
| # Tokenize the input text with padding | |
| inputs = tokenizer(seed_text, return_tensors='pt', padding=True, truncation=True) | |
| input_ids = inputs['input_ids'].to(device) | |
| attention_mask = inputs['attention_mask'].to(device) | |
| # Generate text | |
| with torch.no_grad(): | |
| output = model.generate( | |
| input_ids, | |
| attention_mask=attention_mask, | |
| max_length=max_length, | |
| temperature=temperature, | |
| num_return_sequences=num_return_sequences, | |
| do_sample=True, | |
| top_k=50, | |
| top_p=0.1, | |
| pad_token_id=tokenizer.eos_token_id # Ensure padding token is set to eos_token_id | |
| ) | |
| # Decode the generated text | |
| generated_texts = [] | |
| for i in range(num_return_sequences): | |
| generated_text = tokenizer.decode(output[i], skip_special_tokens=True) | |
| generated_texts.append(generated_text) | |
| return generated_texts | |
| #home page | |
| def home_page(): | |
| st.title("GUVI TEXT GENERATOR APPLICATION") | |
| st.markdown("Enter the prompt below") | |
| with st.sidebar: | |
| st.title(f"Welcome, {st.session_state.username}!") | |
| st.markdown('<br>',unsafe_allow_html=True) | |
| st.write("### Example Prompts") | |
| st.markdown(''' Guvi is an ''',unsafe_allow_html=True) | |
| st.markdown('<br>',unsafe_allow_html=True) | |
| max=st.slider('Select MAX words',10,250) | |
| st.markdown('<br>',unsafe_allow_html=True) | |
| st.image('Guvi.jpeg', width = 200) | |
| if st.button("Logout"): | |
| st.session_state.clear() | |
| st.session_state.current_page = 'login' | |
| st.rerun() | |
| if "messages" not in st.session_state: | |
| st.session_state.messages = [] | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| if prompt := st.chat_input("Enter the prompt here!"): | |
| with st.chat_message("user"): | |
| st.markdown(prompt) | |
| st.session_state.messages.append({"role": "user", "content": prompt}) | |
| with st.chat_message("assistant"): | |
| response = st.write_stream(generate_text(model, tokenizer, seed_text=prompt, max_length=max, temperature=1.0, num_return_sequences=1)) | |
| st.session_state.messages.append({"role": "assistant", "content": response}) | |
| with col2: | |
| st.image('chatbot.png', width = 500) | |
| # Display appropriate page based on session state | |
| if st.session_state.current_page == 'home': | |
| home_page() | |
| elif st.session_state.current_page == 'login': | |
| login() | |
| elif st.session_state.current_page == 'sign_up': | |
| signup() | |
| elif st.session_state.current_page == 'reset_password': | |
| reset_password_page() | |