# NOTE(review): the next two imports are never used and look like accidental
# IDE auto-imports — kept only to avoid breaking anything out of view.
from operator import ge
from xml.dom.expatbuilder import theDOMImplementation

import gradio as gr
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Cache of loaded (tokenizer, model) pairs keyed by HF model id, so a model
# is downloaded/loaded at most once per process.
model_cache = {}

# UI-facing model names mapped to their Hugging Face Hub model ids.
MODEL_IDS = {
    "Medium-GPTNeo": "tniranjan/finetuned_gptneo-base-tinystories-ta_v3",
    "Small-GPTNeo": "tniranjan/finetuned_tinystories_33M_tinystories_ta",
    "Small-LLaMA": "tniranjan/finetuned_Llama_tinystories_tinystories_ta",
}


def generate(model_name, text, max_new_tokens, top_k):
    """Generate a Tamil tinystory continuation of ``text``.

    Args:
        model_name: One of the keys in ``MODEL_IDS`` (UI dropdown choice).
        text: Prompt text to continue.
        max_new_tokens: Maximum number of tokens to sample (may arrive as a
            float from ``gr.Number``; coerced to int).
        top_k: Top-k sampling cutoff (coerced to int for the same reason).

    Returns:
        The decoded generated text, including the prompt.

    Raises:
        ValueError: If ``model_name`` is not a known model choice.
    """
    # Fail loudly on an unknown choice instead of the NameError the old
    # if/elif chain produced when no branch matched.
    try:
        model_id = MODEL_IDS[model_name]
    except KeyError:
        raise ValueError(f"Unknown model: {model_name!r}") from None

    # Load model and tokenizer (from cache if available).
    if model_id not in model_cache:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id)
        model_cache[model_id] = (tokenizer, model)
    else:
        tokenizer, model = model_cache[model_id]

    inputs = tokenizer(text, return_tensors="pt")

    # gr.Number yields floats; model.generate expects integer sampling params.
    output = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        top_k=int(top_k),
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode generated tokens back to a plain string (drop pad/eos markers).
    return tokenizer.decode(output[0], skip_special_tokens=True)


demo = gr.Interface(
    generate,
    title="Kurunkathai: Tinystories in Tamil",
    description="Generate Tamil stories for toddlers using Kurunkathai. Write the first line or so and click 'Submit' to generate a story.",
    inputs=[
        gr.Dropdown(
            choices=["Medium-GPTNeo", "Small-GPTNeo", "Small-LLaMA"],
            label="Model",
            value="Small-GPTNeo",
        ),
        gr.Textbox(value="சிறிய குட்டி செல்லி, ஒரு அழகான நாய்க்குட்டியைக் கண்டாள்.", label="Text"),
        gr.Number(minimum=25, maximum=250, value=100, step=1, label="Max new tokens"),
        gr.Number(minimum=1, maximum=150, value=35, step=1, label="Top-k"),
    ],
    outputs=[
        gr.Textbox(label="Generated Story"),
    ],
    theme="Monochrome",
)

if __name__ == "__main__":
    demo.launch()