Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import pickle
|
| 5 |
+
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
| 6 |
+
|
| 7 |
+
def _load_pickle(path):
    """Return the object unpickled from the file at ``path``."""
    # NOTE: unpickling can execute arbitrary code, so only trusted artifact
    # files should ever be placed next to this script.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# === Load Classifier ===
classifier_model = tf.keras.models.load_model("classifier_model.keras")
classifier_tokenizer = _load_pickle("classifier_tokenizer.pkl")

# === Load Text Generator ===
textgen_model = tf.keras.models.load_model("textgen_model.keras")
textgen_tokenizer = _load_pickle("textgen_tokenizer.pkl")

# === Settings ===
# Placeholder names: replace with the actual labels.
CLASS_LABELS = ['label_0', 'label_1', 'label_2']
# Sequence length the classifier input is padded to.
MAX_LEN_CLASSIFIER = 100
# The text generator input is padded to MAX_LEN_TEXTGEN - 1 tokens.
MAX_LEN_TEXTGEN = 50
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# === Classifier Inference ===
|
| 24 |
+
def classify_text(text):
    """Score ``text`` with the classifier and map each label to its score.

    The text is tokenized with ``classifier_tokenizer``, padded to
    ``MAX_LEN_CLASSIFIER`` and passed to ``classifier_model``. The first row
    of the prediction is paired with ``CLASS_LABELS`` by position.

    Args:
        text: The raw input string to classify.

    Returns:
        A dict mapping every entry of ``CLASS_LABELS`` to a ``float`` score.
    """
    token_ids = classifier_tokenizer.texts_to_sequences([text])
    model_input = pad_sequences(token_ids, maxlen=MAX_LEN_CLASSIFIER)
    # Only one text was submitted, so only the first prediction row is used.
    scores = classifier_model.predict(model_input, verbose=0)[0]

    label_scores = {}
    for position, label in enumerate(CLASS_LABELS):
        label_scores[label] = float(scores[position])
    return label_scores
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# === Text Generation with Top-k Sampling ===
|
| 32 |
+
def generate_text(seed_text, next_words=30, k=10, temperature=0.9):
    """Extend ``seed_text`` word by word with temperature-scaled top-k sampling.

    Each step tokenizes the text generated so far, pads it to
    ``MAX_LEN_TEXTGEN - 1`` tokens, asks ``textgen_model`` for a distribution
    over token indices, and samples one of the ``k`` most likely indices.
    Indices sampled during the last 10 steps are re-drawn (up to 5 times) to
    reduce immediate repetition.

    Args:
        seed_text: Prompt that the generated words are appended to.
        next_words: Number of sampling steps. Coerced to ``int`` because UI
            sliders may deliver floats.
        k: Number of highest-probability candidates to sample from. Coerced
            to ``int`` and clamped to at least 1.
        temperature: Positive divisor applied to the log-probabilities before
            re-normalizing; lower values sharpen the distribution.

    Returns:
        ``seed_text`` followed by the generated words, separated by spaces.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    # Local import keeps this block self-contained.
    from collections import deque

    if temperature <= 0:
        raise ValueError("temperature must be positive")
    next_words = int(next_words)
    # A non-positive k would turn the ``[-k:]`` slice into "everything".
    k = max(1, int(k))

    # Build the reverse vocabulary once instead of scanning it for every word.
    index_to_word = {
        index: word for word, index in textgen_tokenizer.word_index.items()
    }
    # Bounded FIFO: the oldest index is dropped first, unlike ``set.pop()``.
    recent_words = deque(maxlen=10)
    max_resample_attempts = 5

    for _ in range(next_words):
        token_list = textgen_tokenizer.texts_to_sequences([seed_text])[0]
        token_list = pad_sequences(
            [token_list], maxlen=MAX_LEN_TEXTGEN - 1, padding='pre'
        )
        # Assumes the model returns one row of per-index probabilities for
        # the single padded sequence - TODO confirm against the model.
        predicted_probs = textgen_model.predict(token_list, verbose=0)[0]

        scaled_logits = np.log(predicted_probs + 1e-10) / temperature
        # Subtracting the maximum leaves the normalized result unchanged but
        # prevents every term underflowing to zero at low temperatures.
        exp_probs = np.exp(scaled_logits - np.max(scaled_logits))
        normalized_probs = exp_probs / np.sum(exp_probs)

        top_k_indices = np.argsort(normalized_probs)[-k:]
        top_k_probs = normalized_probs[top_k_indices]
        top_k_probs = top_k_probs / np.sum(top_k_probs)

        predicted_word_index = int(
            np.random.choice(top_k_indices, p=top_k_probs)
        )
        # Avoid repetition: re-draw a limited number of times if the index
        # was produced recently.
        for _attempt in range(max_resample_attempts):
            if predicted_word_index not in recent_words:
                break
            predicted_word_index = int(
                np.random.choice(top_k_indices, p=top_k_probs)
            )
        recent_words.append(predicted_word_index)

        # Indices without a vocabulary entry yield no word; skip them rather
        # than appending a bare space.
        output_word = index_to_word.get(predicted_word_index, "")
        if output_word:
            seed_text += " " + output_word

    return seed_text
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# === Gradio UI ===
|
| 71 |
+
# Two-tab interface: one tab wired to classify_text, one to generate_text.
with gr.Blocks() as demo:
    gr.Markdown("## Dual Model: Text Classifier + Text Generator")

    with gr.Tab("Text Classification"):
        input_text = gr.Textbox(label="Enter Text", lines=4)
        classify_btn = gr.Button("Classify")
        output_label = gr.Label()
        classify_btn.click(fn=classify_text, inputs=input_text, outputs=output_label)

    with gr.Tab("Text Generation"):
        seed_text = gr.Textbox(label="Seed Text", lines=4)
        # step=1 keeps the word count a whole number; it is used as a loop count.
        num_words = gr.Slider(5, 100, value=30, step=1, label="Words to Generate")
        k_top = gr.Slider(1, 50, value=10, step=1, label="Top-k Sampling")
        temp = gr.Slider(0.5, 1.5, value=0.9, label="Temperature")
        generate_btn = gr.Button("Generate")
        gen_output = gr.Textbox(label="Generated Text", lines=6)
        # Input order must match generate_text(seed_text, next_words, k, temperature).
        generate_btn.click(
            fn=generate_text,
            inputs=[seed_text, num_words, k_top, temp],
            outputs=gen_output,
        )

# Start the server only when run as a script, so importing this module
# does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()
|