# kurunkathai / app.py
# (Hugging Face Space header — author tniranjan, commit ac4bdc5 — carried
# over from the web UI; kept as comments so the file parses as Python.)
from operator import ge
from xml.dom.expatbuilder import theDOMImplementation
import gradio as gr
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Optional: cache loaded models to avoid reloading every time
model_cache = {}
def generate(model_name, text, max_new_tokens, top_k):
if model_name == "Medium-GPTNeo":
model_id = "tniranjan/finetuned_gptneo-base-tinystories-ta_v3"
elif model_name == "Small-GPTNeo":
model_id = "tniranjan/finetuned_tinystories_33M_tinystories_ta"
elif model_name == "Small-LLaMA":
model_id = "tniranjan/finetuned_Llama_tinystories_tinystories_ta"
# Load model and tokenizer (from cache if available)
if model_id not in model_cache:
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model_cache[model_id] = (tokenizer, model)
else:
tokenizer, model = model_cache[model_id]
inputs = tokenizer(text, return_tensors="pt")
# Generate text
output = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
top_k=top_k,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
# Decode generated tokens
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
return generated_text
# Assemble the UI: a model picker, a Tamil prompt box, and two sampling
# knobs feed `generate`; the story comes back in one output textbox.
_model_picker = gr.Dropdown(
    choices=["Medium-GPTNeo", "Small-GPTNeo", "Small-LLaMA"],
    label="Model",
    value="Small-GPTNeo",
)
_prompt_box = gr.Textbox(
    value="சிறிய குட்டி செல்லி, ஒரு அழகான நாய்க்குட்டியைக் கண்டாள்.",
    label="Text",
)
_max_tokens_knob = gr.Number(
    minimum=25, maximum=250, value=100, step=1, label="Max new tokens"
)
_top_k_knob = gr.Number(
    minimum=1, maximum=150, value=35, step=1, label="Top-k"
)
_story_output = gr.Textbox(label="Generated Story")

demo = gr.Interface(
    fn=generate,
    title="Kurunkathai: Tinystories in Tamil",
    description="Generate Tamil stories for toddlers using Kurunkathai. Write the first line or so and click 'Submit' to generate a story.",
    inputs=[_model_picker, _prompt_box, _max_tokens_knob, _top_k_knob],
    outputs=[_story_output],
    theme="Monochrome",
)
if __name__ == "__main__":
    # Start the Gradio web server only when this file is run as a script.
    demo.launch()