|
|
|
|
|
import logging
import re

import gradio as gr
import torch
from deep_translator import GoogleTranslator
from langdetect import detect
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
|
|
|
|
# Directory holding the fine-tuned GPT-2 weights and tokenizer files.
MODEL_DIR = "./fine_tuned_model"


def load_model():
    """Load the tokenizer and language model from ``MODEL_DIR``.

    Both are read with ``local_files_only=True``. The tokenizer falls back
    to its EOS token for padding when it has no pad token, and the model is
    moved to CUDA when available (CPU otherwise) and put in eval mode.

    Returns:
        A ``(tokenizer, model, device)`` tuple.
    """
    tok = GPT2Tokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
    lm = GPT2LMHeadModel.from_pretrained(MODEL_DIR, local_files_only=True)

    # Reuse EOS as the padding token when the tokenizer ships without one.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Keep the model config in agreement with the tokenizer's pad id.
    lm.config.pad_token_id = tok.pad_token_id

    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lm.to(target)
    lm.eval()
    return tok, lm, target


# Loaded once at import time and shared by every request.
tokenizer, model, device = load_model()
|
|
|
|
|
# langdetect reports lowercase region tags for Chinese, while the Google
# backend of deep_translator matches language codes case-sensitively.
# NOTE(review): extend if other detector/translator code mismatches show up.
_LANGDETECT_TO_GOOGLE = {"zh-cn": "zh-CN", "zh-tw": "zh-TW"}


def to_en(text):
    """Translate ``text`` to English and report its source language.

    Args:
        text: The raw user message.

    Returns:
        A ``(english_text, lang)`` tuple. ``lang`` is the code to translate
        the reply back into; it is ``"en"`` when the text is already
        English, when detection fails, or when translation fails (so the
        caller answers in English instead of crashing).
    """
    try:
        lang = detect(text)
    except Exception:
        # Any detection failure is treated as English input.
        lang = "en"

    lang = _LANGDETECT_TO_GOOGLE.get(lang, lang)
    if lang == "en":
        return text, "en"

    try:
        translated_text = GoogleTranslator(source=lang, target="en").translate(text)
    except Exception:
        # Unsupported language code or translator/network failure: degrade
        # to an untranslated, English-mode conversation.
        logging.getLogger(__name__).warning(
            "Translation from %r to English failed; using the original text.",
            lang,
            exc_info=True,
        )
        return text, "en"

    if translated_text is None:
        translated_text = text
    return translated_text, lang
|
|
|
|
|
def from_en(text, tgt):
    """Translate English ``text`` into the language ``tgt``.

    Args:
        text: English text to translate.
        tgt: Target language code; ``"en"`` means no translation.

    Returns:
        The translated text, or ``text`` unchanged when the target is
        English, the text is blank, the translator returns ``None``, or the
        translation attempt raises.
    """
    # Nothing to translate: skip the translator call entirely.
    if tgt == "en" or not text or not text.strip():
        return text

    try:
        translated_text = GoogleTranslator(source="en", target=tgt).translate(text)
    except Exception:
        # A failed translation should not take the whole reply down; the
        # English answer is still better than an error.
        logging.getLogger(__name__).warning(
            "Translation from English to %r failed; returning English text.",
            tgt,
            exc_info=True,
        )
        return text

    return translated_text if translated_text is not None else text
|
|
|
|
|
def generate(prompt, max_new_tokens=120, temperature=0.8):
    """Sample a continuation of ``prompt`` from the module-level model.

    Args:
        prompt: Text to condition the model on.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The first generated sequence decoded to a string with special
        tokens skipped.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(device)

    # Top-k / nucleus sampling settings, gathered in one place.
    sampling_options = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=temperature,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        sequences = model.generate(**encoded, **sampling_options)

    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
|
|
|
|
|
def post_process_generated_text(text, prompt):
    """Strip the prompt from generated text and drop stuttered words.

    Args:
        text: Raw decoded model output.
        prompt: The prompt that was fed to the model; every occurrence of
            it is removed from ``text``.

    Returns:
        The remaining words joined by single spaces, with any word that
        repeats its immediate predecessor (case-insensitively) removed.
        Returns ``""`` when nothing is left after removing the prompt.
    """
    words = text.replace(prompt, "").strip().split()
    if not words:
        return ""

    # Compare each word with the one before it; keep it only if it differs.
    deduped = [words[0]]
    deduped.extend(
        current
        for previous, current in zip(words, words[1:])
        if current.lower() != previous.lower()
    )
    return " ".join(deduped)
|
|
|
|
|
|
|
|
# "ai" and "ml" are matched as whole words only; as plain substrings they
# also fire on words such as "explain", "email" or "html".
_AI_ML_WORD_RE = re.compile(r"\b(?:ai|ml)\b")


def recommend_course(t):
    """Return a course title matching keywords in ``t``, or ``None``.

    Matching is case-insensitive and the first matching rule wins, in the
    order: Python, data science, AI/ML, web development, Java.

    Args:
        t: English text of the user's message.

    Returns:
        The recommended course title, or ``None`` when no keyword matches.
    """
    # NOTE(review): the leading glyphs in the titles below look like
    # mis-encoded emoji; they are left untouched here - confirm against the
    # original file before changing them.
    t = t.lower()
    if "python" in t:
        return "π Python Programming β Beginner to Advanced"
    if "data science" in t:
        return "π Data Science Master Program"
    if "machine learning" in t or _AI_ML_WORD_RE.search(t):
        return "π€ AI & Machine Learning with Real Projects"
    # Checked before "java" so that "javascript" lands on the web course.
    if any(keyword in t for keyword in ("web", "full stack", "javascript", "react")):
        return "π Full Stack Web Development"
    if "java" in t:
        return "β Java Programming Essentials"
    return None
|
|
|
|
|
def _build_prompt(en):
    """Build the generation prompt for the English user message ``en``."""
    lowered = en.lower()
    if any(keyword in lowered for keyword in ("what is", "tell me about")):
        # Seed the answer with the topic, i.e. the question minus its lead-in.
        topic = lowered.replace("what is", "").replace("tell me about", "").strip()
        return f"User: {en}\nAssistant: Here is information about {topic}:\n"
    if "recommend" in lowered:
        return f"User: {en}\nAssistant: Based on your request, here is a recommendation:\n"
    return f"User: {en}\nAssistant:"


def chat(user_input, history):
    """Handle one chat turn and return the updated conversation.

    The message is translated to English, answered either with a canned
    course recommendation or with model-generated text, and the answer is
    translated back into the user's language.

    Args:
        user_input: The message typed by the user.
        history: List of ``(user_message, reply)`` tuples so far.

    Returns:
        ``(history, history)`` - the same updated list twice, once for the
        chatbot display and once for the session state. Blank input leaves
        the history unchanged.
    """
    history = history or []
    # Do not spend a generation (or a translation) on an empty message.
    if not user_input or not user_input.strip():
        return history, history

    en, lang = to_en(user_input)
    course = recommend_course(en)
    if course:
        en_resp = f"I recommend you check out: {course}"
    else:
        prompt = _build_prompt(en)
        # Post-processing already removes the echoed prompt, so no further
        # prefix trimming is needed afterwards.
        en_resp = post_process_generated_text(generate(prompt), prompt)

    # An empty reply has nothing to translate.
    final = from_en(en_resp, lang) if en_resp else en_resp
    history = history + [(user_input, final)]
    return history, history
|
|
|
|
|
# UI layout: a chat window, a message box and a clear button. The
# conversation is kept in a per-session ``gr.State`` list.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π Multilingual GPT-2 Chatbot")
    chatbot = gr.Chatbot(height=420)
    msg = gr.Textbox(label="Your Message", placeholder="Type here...")
    clear = gr.Button("ποΈ Clear")
    state = gr.State([])

    # Submitting the textbox runs one chat turn and refreshes both the
    # visible chat and the stored history.
    msg.submit(
        fn=chat,
        inputs=[msg, state],
        outputs=[chatbot, state],
    )

    # The clear button resets the display and the stored history.
    clear.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, state],
        queue=False,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|