File size: 11,703 Bytes
cf6606c
 
c035add
 
 
3dfe622
c035add
3dfe622
 
 
 
 
 
c035add
 
3dfe622
c035add
3dfe622
c035add
 
 
 
 
 
 
 
 
 
 
 
 
 
3dfe622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c035add
 
3dfe622
 
c035add
 
 
3dfe622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c035add
 
 
3dfe622
c035add
 
 
3dfe622
 
c035add
 
3dfe622
c035add
3dfe622
c035add
3dfe622
1baf5eb
3dfe622
 
c035add
 
3dfe622
 
 
 
c035add
3dfe622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c035add
 
3dfe622
 
 
 
 
 
 
 
 
 
c035add
 
 
 
 
 
 
 
 
3dfe622
c035add
 
3dfe622
 
 
d3ad937
3dfe622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c035add
 
 
 
 
3dfe622
 
c035add
 
 
 
 
 
 
 
 
3dfe622
 
 
c035add
 
3dfe622
 
 
c035add
 
3dfe622
c035add
 
3dfe622
c035add
 
 
3dfe622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
import json
import os
import re
import time
from pathlib import Path

import requests
import streamlit as st
import torch
from bs4 import BeautifulSoup
from datasets import load_dataset  # used for loading txt/csv fine-tuning data
from PyPDF2 import PdfReader
from streamlit_chat import message
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,  # was imported twice on the original line; deduplicated
    DataCollatorForLanguageModeling,
    TextDataset,
    Trainer,
    TrainingArguments,
)

# Set the page title and icon shown in the browser tab
st.set_page_config(page_title="GPT-2 Text Uploader and Trainer", page_icon=":robot_face:")

# Custom CSS for styling chat messages and buttons
# (injected raw into the page; unsafe_allow_html is required for <style> tags)
st.markdown(
    """
    <style>
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        border-radius: 12px;
        padding: 10px 24px;
    }
    .stTextArea textarea {
        background-color: #f5f5f5;
    }
    .stDownloadButton>button {
        background-color: #4CAF50;
        color: white;
    }
    .stMessageContainer {
        border-radius: 15px;
        padding: 10px;
        margin: 10px 0;
    }
    .stMessage--user {
        background-color: #dfe7f3;
        border-left: 6px solid #006699;
    }
    .stMessage--assistant {
        background-color: #f3f3f3;
        border-left: 6px solid #4CAF50;
    }
    pre {
        background-color: #f5f5f5;
        border-left: 6px solid #dfe7f3;
        padding: 10px;
        font-size: 14px;
        border-radius: 8px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

# Initialize session state variables on first run.
# setdefault only stores the value when the key is missing, so existing
# state survives Streamlit reruns exactly as the original if-chain did.
st.session_state.setdefault("generated", [])    # model replies, in order
st.session_state.setdefault("past", [])         # user inputs, in order
st.session_state.setdefault("messages", [{"role": "system", "content": "You are a helpful assistant."}])
st.session_state.setdefault("model_name", [])   # model used per exchange
st.session_state.setdefault("total_tokens", [])
st.session_state.setdefault("total_cost", 0.0)
st.session_state.setdefault("chat_data", [])       # For storing the chat logs
st.session_state.setdefault("uploaded_docs", [])   # For storing uploaded document content
st.session_state.setdefault("web_data", [])        # For storing web scraped data
st.session_state.setdefault("uploaded_file_path", "")  # Store the path of saved files

# Sidebar - Model Selection, Style Parameters, and Cost Display
st.sidebar.title("Model Selection")
model_name = "gpt2"  # fixed model id; appended per exchange to session state

# Parameters to adjust the response style and creativity.
# These module-level globals are read by generate_response() on every rerun.
st.sidebar.title("Response Style Controls")
temperature = st.sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1)
top_p = st.sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
top_k = st.sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
# NOTE(review): GPT-2's context window is 1024 tokens, but this slider allows
# up to 4048 — values above the window will fail at generation time; confirm.
max_length = st.sidebar.slider("Max Length", min_value=100, max_value=4048, value=400, step=10)

# Load the model and tokenizer (cached so Streamlit reruns reuse one instance)
@st.cache_resource
def load_model_and_tokenizer():
    """Load and cache the GPT-2 tokenizer and causal-LM model."""
    model_path = "gpt2"  # model identifier (resolved from local cache or the hub)
    tokenizer = AutoTokenizer.from_pretrained(model_path, clean_up_tokenization_spaces=True)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer()
# GPT-2 ships without a pad token; reuse the eos token so padding works
# for both generation and fine-tuning.
tokenizer.pad_token = tokenizer.eos_token

# Sampling is always enabled for this app. Defined before generate_response
# (the original assigned it *after* the def, relying on call-time lookup).
do_sample = True


def generate_response(prompt):
    """Generate a model reply, prepending uploaded-document and web context.

    Args:
        prompt: the raw user message.

    Returns:
        The decoded model output (which includes the echoed context/prompt,
        since GPT-2 generation continues the input sequence).
    """
    context = " ".join(st.session_state['uploaded_docs']) + " " + " ".join(st.session_state['web_data']) + "\n" + prompt
    # GPT-2's context window is 1024 tokens; without truncation a large
    # uploaded document would make generate() fail on position embeddings.
    inputs = tokenizer(context, return_tensors="pt", truncation=True, max_length=1024)

    generation_config = {
        "max_length": max_length,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": do_sample,
    }
    # Sampling-only knobs: the original passed them as None when sampling was
    # off, which newer transformers releases warn about — include them only
    # when they actually apply.
    if do_sample:
        generation_config["temperature"] = temperature
        generation_config["top_p"] = top_p

    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        **generation_config,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Function to reset the session
def reset_session():
    """Reset every chat-related session-state key back to its startup value."""
    fresh_defaults = {
        "generated": [],
        "past": [],
        "messages": [{"role": "system", "content": "You are a helpful assistant."}],
        "model_name": [],
        "total_tokens": [],
        "total_cost": 0.0,
        "chat_data": [],       # chat logs
        "uploaded_docs": [],   # uploaded document text
        "web_data": [],        # scraped website text
    }
    for state_key, empty_value in fresh_defaults.items():
        st.session_state[state_key] = empty_value


# Reset chat button in sidebar
reset_button = st.sidebar.button("Reset Chat")
if reset_button:
    reset_session()


# Function to save chat logs for later fine-tuning
def save_chat_data(chat_data):
    """Persist the chat log to chat_data.json in the working directory.

    Args:
        chat_data: list of {"user_input": ..., "model_response": ...} dicts.
    """
    # Explicit UTF-8 avoids the platform-dependent default encoding on
    # Windows, and ensure_ascii=False keeps Arabic (and other non-Latin)
    # chat text readable in the log instead of \uXXXX escapes.
    with open("chat_data.json", "w", encoding="utf-8") as f:
        json.dump(chat_data, f, indent=4, ensure_ascii=False)


# Function to handle uploaded text or PDF files and convert PDF to txt
def handle_uploaded_file(uploaded_file):
    """Save an uploaded TXT/PDF into ./dataset, extracting PDF text to .txt.

    The saved file's path is stored in st.session_state["uploaded_file_path"]
    so the fine-tuning step can find it later.
    """
    dataset_dir = Path("./dataset")
    # The original never created this directory, so the first upload crashed
    # with FileNotFoundError on a fresh checkout.
    dataset_dir.mkdir(parents=True, exist_ok=True)
    dataset_path = dataset_dir / f"{uploaded_file.name}.txt"

    # Check if the file is a PDF
    if uploaded_file.type == "application/pdf":
        # Read and extract text from the PDF.
        # extract_text() can return None for image-only pages; guard with "".
        pdf_reader = PdfReader(uploaded_file)
        text = "".join(page.extract_text() or "" for page in pdf_reader.pages)

        # Save extracted text as a .txt file (UTF-8 so non-Latin text survives)
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(text)
        st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
    else:
        # If it's a text file, save its raw bytes as is
        with open(dataset_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.success(f"File saved to {dataset_path}")

    st.session_state["uploaded_file_path"] = str(dataset_path)


# Add a file uploader for various formats
st.sidebar.title("Upload Documents")
uploaded_file = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])

# Process uploaded file
# NOTE(review): this runs on every Streamlit rerun while a file is selected,
# so the same file is re-saved repeatedly — harmless but redundant; confirm.
if uploaded_file is not None:
    handle_uploaded_file(uploaded_file)


# Function to fetch and scrape website content
def handle_web_link(url):
    """Fetch a web page, extract its visible text, and append it to session state.

    Shows a success or error banner in the Streamlit UI; never raises.
    """
    try:
        # A timeout keeps the Streamlit script from hanging forever on a
        # dead host; the original had none and any network error (DNS
        # failure, refused connection) crashed the whole app.
        response = requests.get(url, timeout=10)
    except requests.RequestException as exc:
        st.error(f"Failed to retrieve content from {url}. Error: {exc}")
        return

    if response.status_code == 200:
        soup = BeautifulSoup(response.content, "html.parser")
        text = soup.get_text()
        st.session_state["web_data"].append(text)
        st.success(f"Content from {url} saved successfully!")
    else:
        st.error(f"Failed to retrieve content from {url}. Status code: {response.status_code}")


# Add a text box for entering website links
st.sidebar.title("Add Website Links")
web_link = st.sidebar.text_input("Enter Website URL")

# Process web link
# NOTE(review): the URL is re-fetched on every Streamlit rerun while the text
# box is non-empty, appending duplicate copies to web_data — confirm intent.
if web_link:
    handle_web_link(web_link)


# Containers for chat history and user input.
# response_container is declared first so the history renders above the form.
response_container = st.container()
container = st.container()

with container:
    # Form groups the text area and button so a single click submits both.
    with st.form(key="user_input_form"):
        user_input = st.text_area("You:", key="user_input", height=100)
        submit_button = st.form_submit_button("Send")

    if submit_button and user_input:
        start_time = time.time()
        output = generate_response(user_input)
        end_time = time.time()
        # NOTE(review): inference_time is computed but never displayed or
        # logged anywhere below.
        inference_time = end_time - start_time

        # Append user input and model output to session state
        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)
        st.session_state["model_name"].append(model_name)

        # Log chat data for future training
        st.session_state["chat_data"].append(
            {"user_input": user_input, "model_response": output}
        )

        # Save chat data to a file (this could be used later for training)
        save_chat_data(st.session_state["chat_data"])

        # TODO: Calculate tokens and cost (placeholder — never implemented;
        # total_tokens / total_cost in session state stay unused)
      

        # Display chat history (every past exchange, oldest first)
        with response_container:
            for i in range(len(st.session_state["generated"])):
                message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
                message(st.session_state["generated"][i], key=str(i))
               

# Function to fine-tune the model using uploaded dataset
def fine_tune_model():
    """Fine-tune the global GPT-2 model on the dataset path in session state.

    Fixes over the original:
    - an unsupported file extension no longer falls through with `dataset`
      unbound (NameError); it now shows an error and returns;
    - evaluation options (eval_strategy="steps", load_best_model_at_end,
      metric_for_best_model='accuracy') are dropped — no eval_dataset or
      compute_metrics is ever supplied, so Trainer refused to run with them.
    """
    uploaded_file_path = st.session_state.get("uploaded_file_path", None)
    if not uploaded_file_path:
        st.warning("يرجى تحميل dataset لتدريب النموذج.")
        return

    # Load the raw data from a plain-text or CSV file.
    if uploaded_file_path.endswith('.txt'):
        dataset = load_dataset('text', data_files=uploaded_file_path, split='train')
    elif uploaded_file_path.endswith('.csv'):
        dataset = load_dataset('csv', data_files=uploaded_file_path, split='train')
    else:
        st.error("Unsupported dataset format. Please upload a .txt or .csv file.")
        return

    # Tokenize each example to a fixed 512-token window.
    # NOTE(review): for CSV input this assumes a column named 'text' — confirm.
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Causal-LM objective: mlm=False means labels are the shifted inputs.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # fp16 only works (and only helps) when a CUDA GPU is available.
    use_fp16 = torch.cuda.is_available()

    training_args = TrainingArguments(
        output_dir='./gpt2-finetuned',
        overwrite_output_dir=True,
        num_train_epochs=4,
        per_device_train_batch_size=3,
        save_steps=500,
        learning_rate=2e-5,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=100,
        save_total_limit=3,  # keep only the newest 3 checkpoints on disk
        fp16=use_fp16,
        # keep the tokenized columns so the collator can see them
        remove_unused_columns=False,
    )

    # Build the Trainer with the tokenized dataset and collator.
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_dataset,
    )

    # Run the training loop (blocks the Streamlit script until done).
    trainer.train()

    st.success("تم إكمال تدريب النموذج بنجاح.")

# واجهة Streamlit لتحميل dataset وبدء التدريب
st.title("Fine-tune GPT-2 Model")

uploaded_file = st.file_uploader("Upload your dataset (TXT or CSV)", type=['txt', 'csv'])
if uploaded_file:
    st.session_state["uploaded_file_path"] = uploaded_file.name
    with open(uploaded_file.name, "wb") as f:
        f.write(uploaded_file.getbuffer())
    st.success(f"File {uploaded_file.name} uploaded successfully.")

if st.button("Start Fine-tuning"):
    fine_tune_model()