# recursive-bot / app.py
# Last commit by Tyreid0saurus: "Update app.py" (4347369, verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import multiprocessing
# Model setup. These module-level globals (`tokenizer`, `generator`) are read
# by run_generation below, which runs inside a multiprocessing child process.
model_id = "distilgpt2"

# Load the tokenizer and causal-LM weights for the checkpoint named above.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Wrap them in a text-generation pipeline so callers can pass a raw prompt string.
generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
def run_generation(prompt, return_dict):
    """Sample one continuation of ``prompt`` and store it in ``return_dict``.

    This function is used as a ``multiprocessing.Process`` target, so it reports
    its outcome by writing ``return_dict["result"]`` instead of returning a
    value. Any exception is caught and stored as a
    ``"GENERATION ERROR: ..."`` string rather than propagated.

    Args:
        prompt: Text to continue.
        return_dict: Mapping shared with the parent process (a Manager dict
            proxy in this file); only the ``"result"`` key is written.
    """
    try:
        pipe = generator
        # Short, moderately random sample; stop early at the end-of-text token.
        sampling_options = {
            "max_new_tokens": 48,
            "do_sample": True,
            "temperature": 0.7,
            "eos_token_id": tokenizer.eos_token_id,
        }
        candidates = pipe(prompt, **sampling_options)
        first_candidate = candidates[0]
        return_dict["result"] = first_candidate["generated_text"]
    except Exception as exc:
        return_dict["result"] = f"GENERATION ERROR: {exc}"
def generate_with_hard_timeout(prompt, timeout=15):
    """Run ``run_generation`` in a child process, killing it after ``timeout`` seconds.

    Generation runs in a separate process rather than a thread, so a hung call
    can be stopped with ``Process.terminate()``.

    Args:
        prompt: Text passed unchanged to ``run_generation``.
        timeout: Seconds to wait for the child before terminating it.

    Returns:
        One of the following strings:
        - the value ``run_generation`` stored under ``"result"`` (generated
          text or a ``"GENERATION ERROR: ..."`` message);
        - ``"ERROR: Generation timed out."`` if the child was still running
          after ``timeout`` seconds;
        - an error message with the exit code if the child exited without
          storing a result (e.g. it was killed by a signal).
    """
    # The context manager shuts down the Manager's server process on every
    # path. Previously it was only reclaimed when garbage-collected.
    with multiprocessing.Manager() as manager:
        return_dict = manager.dict()
        p = multiprocessing.Process(target=run_generation, args=(prompt, return_dict))
        p.start()
        p.join(timeout)
        if p.is_alive():
            p.terminate()
            # Reap the terminated child so it does not linger as a zombie.
            p.join()
            return "ERROR: Generation timed out."
        # A child that died abnormally never writes "result". Use .get() so
        # this case is reported instead of raising KeyError.
        result = return_dict.get("result")
    if result is None:
        return f"ERROR: Generation process exited with code {p.exitcode} without a result."
    return result
def chat(input_text):
    """Gradio handler: return the model's continuation of ``input_text``.

    A newline is appended to the user text to form the prompt. The generated
    text is expected to begin with that prompt; the original code already
    assumed this by slicing it off. In that case the prompt is removed so only
    the new text remains. Error strings from the generation helpers do not
    start with the prompt, so they are returned whole instead of truncated.

    Args:
        input_text: Text from the Gradio textbox.

    Returns:
        The stripped reply; the stripped full output if the reply is empty;
        or an error message string.
    """
    prompt = input_text + "\n"
    try:
        output = generate_with_hard_timeout(prompt)
        if not isinstance(output, str):
            return "ERROR: Unexpected output format."
        # Only strip the prompt when the output actually echoes it. Otherwise
        # messages like "ERROR: Generation timed out." lost their first
        # len(prompt) characters.
        if output.startswith(prompt):
            reply = output[len(prompt):].strip()
        else:
            reply = output.strip()
        # strip() already removed all whitespace, so an emptiness check is
        # enough. Fall back to the whole output when nothing new was generated.
        return reply or output.strip()
    except Exception as e:
        return f"GENERATION ERROR: {e}"
# Single-textbox Gradio UI around `chat`. It is kept at module level so
# tooling that imports this file can still find `demo`.
demo = gr.Interface(
    fn=chat,
    inputs=gr.Textbox(label="input_text"),
    outputs="text",
    title="Kairon (Unprimed)",
    allow_flagging="never",
)
# Enable Gradio's request queue for this interface.
demo.queue()

# Start the server only when this file is run as a script.
# generate_with_hard_timeout starts multiprocessing children. Under the
# "spawn"/"forkserver" start methods the main module is re-imported in the
# child, and an unguarded launch() would try to start another server there.
if __name__ == "__main__":
    demo.launch()