Update app.py
Browse files
app.py
CHANGED
|
@@ -9,12 +9,11 @@ MODEL_FILE = "gpt2-q4_k_m.gguf"
|
|
| 9 |
CACHE_DIR = "./model_cache"
|
| 10 |
MAX_TOKENS = 200
|
| 11 |
|
| 12 |
-
# Initialize model
|
| 13 |
def load_model():
|
| 14 |
"""Download and load GGUF model with proper path handling"""
|
| 15 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 16 |
|
| 17 |
-
# Download model if not cached
|
| 18 |
model_path = hf_hub_download(
|
| 19 |
repo_id=MODEL_REPO,
|
| 20 |
filename=MODEL_FILE,
|
|
@@ -24,7 +23,7 @@ def load_model():
|
|
| 24 |
|
| 25 |
return Llama(
|
| 26 |
model_path=model_path,
|
| 27 |
-
n_ctx=
|
| 28 |
n_threads=4,
|
| 29 |
verbose=False
|
| 30 |
)
|
|
@@ -32,11 +31,14 @@ def load_model():
|
|
| 32 |
# Load model at startup
|
| 33 |
llm = load_model()
|
| 34 |
|
| 35 |
-
# Generation function with
|
| 36 |
def generate_text(prompt, max_tokens=MAX_TOKENS, temp=0.7, top_p=0.95):
|
| 37 |
-
"""Generate text with repetition prevention and
|
| 38 |
if not prompt.strip():
|
| 39 |
-
return "Please enter a valid prompt."
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
try:
|
| 42 |
output = llm(
|
|
@@ -45,7 +47,6 @@ def generate_text(prompt, max_tokens=MAX_TOKENS, temp=0.7, top_p=0.95):
|
|
| 45 |
temperature=temp,
|
| 46 |
top_p=top_p,
|
| 47 |
echo=False,
|
| 48 |
-
# Anti-repetition parameters
|
| 49 |
repeat_penalty=1.2,
|
| 50 |
no_repeat_ngram_size=3
|
| 51 |
)
|
|
@@ -62,12 +63,10 @@ with gr.Blocks(theme="soft") as demo:
|
|
| 62 |
|
| 63 |
with gr.Row():
|
| 64 |
with gr.Column():
|
| 65 |
-
# Input components
|
| 66 |
prompt = gr.Textbox(
|
| 67 |
label="Input Prompt",
|
| 68 |
-
placeholder="Enter your prompt here...",
|
| 69 |
-
lines=5
|
| 70 |
-
min_length=10
|
| 71 |
)
|
| 72 |
max_tokens = gr.Slider(
|
| 73 |
minimum=50,
|
|
@@ -92,16 +91,13 @@ with gr.Blocks(theme="soft") as demo:
|
|
| 92 |
)
|
| 93 |
|
| 94 |
with gr.Column():
|
| 95 |
-
# Output and button
|
| 96 |
output = gr.Textbox(label="Generated Text", lines=10)
|
| 97 |
generate_btn = gr.Button("🚀 Generate", variant="primary")
|
| 98 |
|
| 99 |
-
# Event handler
|
| 100 |
generate_btn.click(
|
| 101 |
fn=generate_text,
|
| 102 |
inputs=[prompt, max_tokens, temp, top_p],
|
| 103 |
outputs=output
|
| 104 |
)
|
| 105 |
|
| 106 |
-
# Launch app
|
| 107 |
demo.launch()
|
|
|
|
| 9 |
CACHE_DIR = "./model_cache"
|
| 10 |
MAX_TOKENS = 200
|
| 11 |
|
| 12 |
+
# Initialize model
|
| 13 |
def load_model():
|
| 14 |
"""Download and load GGUF model with proper path handling"""
|
| 15 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 16 |
|
|
|
|
| 17 |
model_path = hf_hub_download(
|
| 18 |
repo_id=MODEL_REPO,
|
| 19 |
filename=MODEL_FILE,
|
|
|
|
| 23 |
|
| 24 |
return Llama(
|
| 25 |
model_path=model_path,
|
| 26 |
+
n_ctx=1024, # Match model's training context length
|
| 27 |
n_threads=4,
|
| 28 |
verbose=False
|
| 29 |
)
|
|
|
|
| 31 |
# Load model at startup
|
| 32 |
llm = load_model()
|
| 33 |
|
| 34 |
+
# Generation function with validation
|
| 35 |
def generate_text(prompt, max_tokens=MAX_TOKENS, temp=0.7, top_p=0.95):
|
| 36 |
+
"""Generate text with repetition prevention and input validation"""
|
| 37 |
if not prompt.strip():
|
| 38 |
+
return "⚠️ Please enter a valid prompt."
|
| 39 |
+
|
| 40 |
+
if len(prompt.split()) < 3: # Minimum word count
|
| 41 |
+
return "⚠️ Please enter at least 3 words for better results."
|
| 42 |
|
| 43 |
try:
|
| 44 |
output = llm(
|
|
|
|
| 47 |
temperature=temp,
|
| 48 |
top_p=top_p,
|
| 49 |
echo=False,
|
|
|
|
| 50 |
repeat_penalty=1.2,
|
| 51 |
no_repeat_ngram_size=3
|
| 52 |
)
|
|
|
|
| 63 |
|
| 64 |
with gr.Row():
|
| 65 |
with gr.Column():
|
|
|
|
| 66 |
prompt = gr.Textbox(
|
| 67 |
label="Input Prompt",
|
| 68 |
+
placeholder="Enter your prompt here... (at least 3 words)",
|
| 69 |
+
lines=5
|
|
|
|
| 70 |
)
|
| 71 |
max_tokens = gr.Slider(
|
| 72 |
minimum=50,
|
|
|
|
| 91 |
)
|
| 92 |
|
| 93 |
with gr.Column():
|
|
|
|
| 94 |
output = gr.Textbox(label="Generated Text", lines=10)
|
| 95 |
generate_btn = gr.Button("🚀 Generate", variant="primary")
|
| 96 |
|
|
|
|
| 97 |
generate_btn.click(
|
| 98 |
fn=generate_text,
|
| 99 |
inputs=[prompt, max_tokens, temp, top_p],
|
| 100 |
outputs=output
|
| 101 |
)
|
| 102 |
|
|
|
|
| 103 |
demo.launch()
|