Spaces:
Sleeping
Sleeping
| from operator import ge | |
| from xml.dom.expatbuilder import theDOMImplementation | |
| import gradio as gr | |
| import os | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
# Loaded (tokenizer, model) pairs keyed by model identifier, so each
# checkpoint is loaded at most once per process instead of on every request.
model_cache = {}

# UI model label -> identifier passed to `from_pretrained`.
MODEL_IDS = {
    "Medium-GPTNeo": "tniranjan/finetuned_gptneo-base-tinystories-ta_v3",
    "Small-GPTNeo": "tniranjan/finetuned_tinystories_33M_tinystories_ta",
    "Small-LLaMA": "tniranjan/finetuned_Llama_tinystories_tinystories_ta",
}


def _load_tokenizer_and_model(model_id):
    """Return the (tokenizer, model) pair for *model_id*, loading it on first use."""
    try:
        return model_cache[model_id]
    except KeyError:
        pass
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    # Only cache once both loads succeeded, so a failed load is retried later.
    model_cache[model_id] = (tokenizer, model)
    return tokenizer, model


def generate(model_name, text, max_new_tokens, top_k):
    """Continue *text* with the selected model and return the decoded result.

    Args:
        model_name: One of the keys of ``MODEL_IDS`` (the dropdown labels).
        text: Prompt to tokenize and feed to the model.
        max_new_tokens: Upper bound on newly generated tokens; coerced to int.
        top_k: Top-k sampling cutoff; coerced to int.

    Returns:
        The first sequence returned by ``model.generate``, decoded with
        special tokens skipped.

    Raises:
        ValueError: If *model_name* is not a known model label.
    """
    try:
        model_id = MODEL_IDS[model_name]
    except KeyError:
        # Previously an unknown label fell through the if/elif chain and
        # surfaced as an UnboundLocalError on `model_id`.
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {sorted(MODEL_IDS)}"
        ) from None

    # The values come from numeric UI widgets and may arrive as floats
    # (NOTE(review): depends on the Gradio version/precision setting);
    # the sampling parameters must be integers.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    tokenizer, model = _load_tokenizer_and_model(model_id)
    inputs = tokenizer(text, return_tensors="pt")

    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        do_sample=True,
        # Reuse EOS as the padding id, as the original call did.
        pad_token_id=tokenizer.eos_token_id,
    )

    return tokenizer.decode(output[0], skip_special_tokens=True)
# Gradio UI: a model picker, a prompt box and two sampling controls feed
# `generate`; the returned story is shown in a single text box.
demo = gr.Interface(
    fn=generate,
    inputs=[
        # Which fine-tuned checkpoint to use.
        gr.Dropdown(
            label="Model",
            choices=["Medium-GPTNeo", "Small-GPTNeo", "Small-LLaMA"],
            value="Small-GPTNeo",
        ),
        # Story opening; pre-filled with a sample first line.
        gr.Textbox(
            label="Text",
            value="சிறிய குட்டி செல்லி, ஒரு அழகான நாய்க்குட்டியைக் கண்டாள்.",
        ),
        gr.Number(label="Max new tokens", value=100, minimum=25, maximum=250, step=1),
        gr.Number(label="Top-k", value=35, minimum=1, maximum=150, step=1),
    ],
    outputs=[gr.Textbox(label="Generated Story")],
    title="Kurunkathai: Tinystories in Tamil",
    description="Generate Tamil stories for toddlers using Kurunkathai. Write the first line or so and click 'Submit' to generate a story.",
    theme="Monochrome",
)

# Start the web app only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()