nnsohamnn committed on
Commit
35b7ac6
·
verified ·
1 Parent(s): 1e7f904

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pickle
from collections import deque

import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
6
+
7
# === Load Classifier ===
# Keras classifier plus the fitted tokenizer it was trained with.
# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file;
# only load tokenizer pickles from a trusted source.
classifier_model = tf.keras.models.load_model("classifier_model.keras")
with open("classifier_tokenizer.pkl", "rb") as f:
    classifier_tokenizer = pickle.load(f)

# === Load Text Generator ===
# Autoregressive next-word model plus its fitted tokenizer (same pickle caveat).
textgen_model = tf.keras.models.load_model("textgen_model.keras")
with open("textgen_tokenizer.pkl", "rb") as f:
    textgen_tokenizer = pickle.load(f)

# === Settings ===
CLASS_LABELS = ['label_0', 'label_1', 'label_2']  # Replace with actual labels
# Sequence lengths fed to each model — presumably what they were trained
# with; verify against the training pipeline.
MAX_LEN_CLASSIFIER = 100
MAX_LEN_TEXTGEN = 50
21
+
22
+
23
+ # === Classifier Inference ===
24
def classify_text(text):
    """Classify *text* and return a mapping of class label -> probability.

    The input is tokenized with the classifier's fitted tokenizer, padded
    to MAX_LEN_CLASSIFIER, and run through the classifier model; the
    model's output probabilities are paired with CLASS_LABELS in order.
    """
    sequences = classifier_tokenizer.texts_to_sequences([text])
    model_input = pad_sequences(sequences, maxlen=MAX_LEN_CLASSIFIER)
    probabilities = classifier_model.predict(model_input, verbose=0)[0]
    return {label: float(probabilities[i]) for i, label in enumerate(CLASS_LABELS)}
29
+
30
+
31
+ # === Text Generation with Top-k Sampling ===
32
def generate_text(seed_text, next_words=30, k=10, temperature=0.9):
    """Autoregressively extend *seed_text* with the text-generation model.

    Each step applies temperature-scaled softmax to the model's output
    distribution, restricts sampling to the top-k tokens (renormalized),
    and re-samples a few times to discourage immediate repetition.

    Parameters:
        seed_text: starting prompt; the generated words are appended to it.
        next_words: number of words to generate.
        k: number of highest-probability tokens eligible for sampling.
        temperature: softmax temperature; higher values flatten the
            distribution (more random output).

    Returns:
        The seed text with the generated words appended.
    """
    # Build the reverse vocabulary once — the original scanned
    # word_index (O(vocab)) for every generated word.
    index_to_word = {index: word for word, index in textgen_tokenizer.word_index.items()}

    # Sliding window of the last 10 emitted token ids. A deque with
    # maxlen keeps true insertion order; the original used a set whose
    # pop() discards an *arbitrary* element, breaking the recency window.
    recent_words = deque(maxlen=10)

    for _ in range(next_words):
        token_list = textgen_tokenizer.texts_to_sequences([seed_text])[0]
        # Model presumably expects MAX_LEN_TEXTGEN-1 input tokens — confirm
        # against the training pipeline.
        token_list = pad_sequences([token_list], maxlen=MAX_LEN_TEXTGEN - 1, padding='pre')
        predicted_probs = textgen_model.predict(token_list, verbose=0)[0]

        # Temperature-scaled softmax (epsilon guards log(0)).
        scaled_logits = np.log(predicted_probs + 1e-10) / temperature
        exp_probs = np.exp(scaled_logits)
        normalized_probs = exp_probs / np.sum(exp_probs)

        # Keep only the k most probable tokens and renormalize.
        top_k_indices = np.argsort(normalized_probs)[-k:]
        top_k_probs = normalized_probs[top_k_indices]
        top_k_probs = top_k_probs / np.sum(top_k_probs)

        predicted_word_index = np.random.choice(top_k_indices, p=top_k_probs)

        # Re-sample up to 5 times if this token was emitted recently.
        attempts = 0
        while predicted_word_index in recent_words and attempts < 5:
            predicted_word_index = np.random.choice(top_k_indices, p=top_k_probs)
            attempts += 1

        recent_words.append(predicted_word_index)

        # Index 0 is the Keras padding id and has no vocabulary entry;
        # skip unknown ids instead of appending a stray trailing space.
        output_word = index_to_word.get(int(predicted_word_index), "")
        if output_word:
            seed_text += " " + output_word

    return seed_text
68
+
69
+
70
+ # === Gradio UI ===
71
# === Gradio UI ===
# Two tabs: one wired to classify_text, one to generate_text.
with gr.Blocks() as demo:
    gr.Markdown("## Dual Model: Text Classifier + Text Generator")

    with gr.Tab("Text Classification"):
        classifier_input = gr.Textbox(label="Enter Text", lines=4)
        run_classifier = gr.Button("Classify")
        classifier_output = gr.Label()
        run_classifier.click(
            fn=classify_text,
            inputs=classifier_input,
            outputs=classifier_output,
        )

    with gr.Tab("Text Generation"):
        generator_seed = gr.Textbox(label="Seed Text", lines=4)
        word_count = gr.Slider(5, 100, value=30, label="Words to Generate")
        top_k_choice = gr.Slider(1, 50, value=10, step=1, label="Top-k Sampling")
        sampling_temp = gr.Slider(0.5, 1.5, value=0.9, label="Temperature")
        run_generator = gr.Button("Generate")
        generator_output = gr.Textbox(label="Generated Text", lines=6)
        run_generator.click(
            fn=generate_text,
            inputs=[generator_seed, word_count, top_k_choice, sampling_temp],
            outputs=generator_output,
        )

demo.launch()