zaid002 committed on
Commit 2196aef · verified · 1 Parent(s): 0c0f340

Create app.py

Files changed (1)
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
# app_gradio.py
import gradio as gr
from deep_translator import GoogleTranslator
from langdetect import detect
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import re  # Import regex for post-processing

MODEL_DIR = "./fine_tuned_model"

def load_model():
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
    model = GPT2LMHeadModel.from_pretrained(MODEL_DIR, local_files_only=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.pad_token_id
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    return tokenizer, model, device

tokenizer, model, device = load_model()

def to_en(text):
    try:
        lang = detect(text)
    except Exception:
        lang = "en"
    if lang == "en":
        return text, "en"
    translated_text = GoogleTranslator(source=lang, target="en").translate(text)
    # Handle potential None return from translator
    return translated_text if translated_text is not None else text, lang

def from_en(text, tgt):
    if tgt == "en":
        return text
    translated_text = GoogleTranslator(source="en", target=tgt).translate(text)
    # Handle potential None return from translator
    return translated_text if translated_text is not None else text

def generate(prompt, max_new_tokens=120, temperature=0.8):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=temperature,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)

def post_process_generated_text(text, prompt):
    # Simple post-processing to clean up potential repetitions or unwanted tokens
    cleaned_text = text.replace(prompt, "").strip()  # Remove the prompt from the output

    # Remove consecutive repeated words (case-insensitive comparison)
    words = cleaned_text.split()
    if not words:
        return ""
    cleaned_words = [words[0]]
    for i in range(1, len(words)):
        if words[i].lower() != words[i - 1].lower():
            cleaned_words.append(words[i])
    return " ".join(cleaned_words)


def recommend_course(t):
    t = t.lower()
    if "python" in t: return "🐍 Python Programming – Beginner to Advanced"
    if "data science" in t: return "📊 Data Science Master Program"
    if "ai" in t or "machine learning" in t or "ml" in t: return "🤖 AI & Machine Learning with Real Projects"
    if "web" in t or "full stack" in t or "javascript" in t or "react" in t: return "🌐 Full Stack Web Development"
    if "java" in t: return "☕ Java Programming Essentials"
    return None

def chat(user_input, history):
    en, lang = to_en(user_input)
    course = recommend_course(en)
    if course:
        en_resp = f"I recommend you check out: {course}"
    else:
        # Modify prompt to encourage structured output based on keywords
        prompt = f"User: {en}\nAssistant:"
        if any(keyword in en.lower() for keyword in ["what is", "tell me about"]):
            prompt = f"User: {en}\nAssistant: Here is information about {en.lower().replace('what is', '').replace('tell me about', '').strip()}:\n"
        elif "recommend" in en.lower():
            prompt = f"User: {en}\nAssistant: Based on your request, here is a recommendation:\n"

        en_resp = generate(prompt)

        # Apply post-processing to clean the generated text
        en_resp = post_process_generated_text(en_resp, prompt)

        if en_resp.startswith(prompt):
            en_resp = en_resp[len(prompt):].strip()

    final = from_en(en_resp, lang)
    history = history + [(user_input, final)]
    return history, history

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌐 Multilingual GPT-2 Chatbot")
    chatbot = gr.Chatbot(height=420)
    msg = gr.Textbox(label="Your Message", placeholder="Type here...")
    clear = gr.Button("🗑️ Clear")
    state = gr.State([])
    msg.submit(chat, [msg, state], [chatbot, state])
    clear.click(lambda: ([], []), None, [chatbot, state], queue=False)

# You can run this in a separate cell using !python app_gradio.py if needed,
# but running it directly in a notebook cell is also possible.
# Launch the interface when this file is executed (e.g. as the Space's app.py).
if __name__ == "__main__":
    demo.launch()
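A minimal way to exercise this pipeline outside the Gradio UI is sketched below; it is not part of the commit. It assumes the file above is importable as app_gradio (adjust the import if it is saved as app.py), that ./fine_tuned_model exists with the packages above installed, and that network access is available for deep_translator; the inputs are illustrative only.

# smoke_test.py (hypothetical, not part of this commit)
from app_gradio import to_en, from_en, chat

# Round-trip a non-English message through the translation helpers
text_en, lang = to_en("Hola, quiero aprender Python")  # langdetect should report "es"
print(lang, "->", text_en)
print(from_en("I recommend the Python course", lang))

# One chat turn; the keyword "python" should trigger the rule-based course recommendation
history, _ = chat("I want to learn python", [])
print(history[-1][1])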