Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Pothana Base 300M — Telugu Text Generation Demo | |
| ================================================ | |
| Gradio app for the Pothana Base 300M Telugu language model. | |
| Handles raw Telugu text input with automatic Morfessor segmentation. | |
| Developed by Dvitva AI. | |
| """ | |
| import re | |
| import torch | |
| import morfessor | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from huggingface_hub import hf_hub_download | |
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
MODEL_ID = "dvitvaai/pothana-base-300M"

# Telugu Unicode block (U+0C00-U+0C7F). Defined once so the "is this Telugu"
# test and the script-boundary split below can never disagree.
_TELUGU_CHARS = r"[\u0C00-\u0C7F]"
TELUGU_WORD_RE = re.compile(_TELUGU_CHARS + "+")
# Capturing variant: re.split() keeps the captured Telugu runs in its result,
# so a mixed token splits into alternating non-Telugu / Telugu parts.
_TELUGU_SPLIT_RE = re.compile("(" + _TELUGU_CHARS + "+)")


# ---------------------------------------------------------------------------
# Morfessor segmentation (from inference.py)
# ---------------------------------------------------------------------------
def _token_pieces(token: str, morf_model) -> list:
    """Split one whitespace-free token into its ordered sub-word pieces.

    Telugu runs are segmented into morphs by Morfessor; non-Telugu runs are
    kept whole. A token containing no Telugu characters is a single piece.
    """
    if not TELUGU_WORD_RE.search(token):
        return [token]
    pieces = []
    # A pure Telugu token splits into ["", token, ""]; empty parts are dropped.
    for part in _TELUGU_SPLIT_RE.split(token):
        if not part:
            continue
        if TELUGU_WORD_RE.fullmatch(part):
            pieces.extend(morf_model.viterbi_segment(part)[0])
        else:
            pieces.append(part)
    return pieces


def segment_text(text: str, morf_model, separator: str = "@@") -> str:
    """Segment raw text using Morfessor with @@ continuation markers.

    Each whitespace-delimited token is broken into pieces:

    - Pure Telugu words -> Morfessor morpheme segments
    - Pure non-Telugu words -> kept as-is (one piece)
    - Mixed-script tokens (e.g. "2024లో") -> split at every script boundary,
      with each Telugu run further segmented by Morfessor

    Every piece except the last piece of its token gets ``separator``
    appended, marking that it continues into the next piece.

    Args:
        text: Raw input text. It is split on whitespace, so leading, trailing
            and repeated whitespace collapse to single spaces.
        morf_model: Morfessor model; ``morf_model.viterbi_segment(word)[0]``
            must be the list of morph strings for ``word``.
        separator: Continuation marker appended to non-final pieces.

    Returns:
        All pieces joined by single spaces ("" for blank input).
    """
    seg_tokens = []
    for token in text.split():
        pieces = _token_pieces(token, morf_model)
        # Non-final pieces continue the word; the final piece ends it. If
        # Morfessor returned nothing, no dangling separator is emitted.
        seg_tokens.extend(piece + separator for piece in pieces[:-1])
        seg_tokens.extend(pieces[-1:])
    return " ".join(seg_tokens)
# ---------------------------------------------------------------------------
# Load model, tokenizer, and Morfessor at startup
# ---------------------------------------------------------------------------
# Morfessor segmentation model, fetched from the same hub repo as the LM.
# NOTE(review): Morfessor's binary model format is a Python pickle, so
# deserializing it can run arbitrary code -- keep MODEL_ID on a trusted repo.
print("Loading Morfessor model...")
morf_path = hf_hub_download(repo_id=MODEL_ID, filename="morfessor_telugu.bin")
morf_io = morfessor.MorfessorIO()
morf_model = morf_io.read_binary_model_file(morf_path)

print("Loading model and tokenizer...")
# Half precision only when a GPU is present; CPU inference stays float32.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype)
model = model.to(device).eval()  # inference only

# trust_remote_code=True executes the custom tokenizer code shipped in the
# hub repo (needed for the @@-aware Telugu tokenizer).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
print(f"Model loaded on {device} ({dtype}). Ready!")
# ---------------------------------------------------------------------------
# Generation function
# ---------------------------------------------------------------------------
def generate(
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.95,
    repetition_penalty: float = 1.1,
):
    """Generate Telugu text from a raw (unsegmented) text prompt.

    The prompt is Morfessor-segmented with ``segment_text`` (``@@`` marks a
    word continuation), tokenized, and continued by sampling from ``model``.

    Args:
        prompt: Raw user text. ``None`` or whitespace-only input returns "".
        max_new_tokens: Upper bound on newly generated tokens; cast to int
            because Gradio sliders may deliver floats.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff; cast to int.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to already-generated tokens.

    Returns:
        ``outputs[0]`` decoded in full with special tokens skipped. For a
        causal LM this includes the (re-decoded) prompt plus the continuation.
    """
    # Gradio may pass None (e.g. an API call without a value); treat as empty.
    prompt = (prompt or "").strip()
    if not prompt:
        return ""

    # Step 1: Morfessor segmentation
    segmented = segment_text(prompt, morf_model, separator="@@")

    # Step 2: Tokenize
    inputs = tokenizer(segmented, return_tensors="pt")
    # Causal-LM generate() rejects unused model kwargs such as token_type_ids;
    # drop it in case the custom tokenizer emits it (no-op otherwise).
    inputs.pop("token_type_ids", None)
    inputs = inputs.to(device)

    # Step 3: Generate
    outputs = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )

    # Step 4: Decode (TeluguTokenizer handles @@ stripping)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Example prompts; each row supplies only the first input (the prompt text).
examples = [
    [example_prompt]
    for example_prompt in (
        "తెలుగు భాష చాలా",
        "భారతదేశం ఒక",
        "ఒకప్పుడు ఒక రాజు",
        "హైదరాబాద్ నగరంలో",
        "విద్యార్థులు పరీక్షలకు",
    )
]

# Sampling sliders, in the order of generate()'s parameters after `prompt`:
# (label, minimum, maximum, default value, step).
_SLIDER_SPECS = (
    ("Max New Tokens", 50, 500, 200, 10),
    ("Temperature", 0.1, 1.5, 0.8, 0.05),
    ("Top-k", 10, 100, 50, 5),
    ("Top-p", 0.5, 1.0, 0.95, 0.05),
    ("Repetition Penalty", 1.0, 1.5, 1.1, 0.05),
)

_prompt_box = gr.Textbox(
    label="Prompt",
    placeholder="తెలుగులో టెక్స్ట్ టైప్ చేయండి...",
    lines=3,
)
_sliders = [
    gr.Slider(minimum=lo, maximum=hi, value=default, step=step, label=label)
    for label, lo, hi, default, step in _SLIDER_SPECS
]

_DESCRIPTION = (
    "**Telugu text generation** powered by Pothana Base 300M, "
    "a ~300M parameter LLaMA-style model trained from scratch on Telugu text. "
    "Named after [Bammera Pothana](https://en.wikipedia.org/wiki/Bammera_Pothana), "
    "the celebrated 15th-century Telugu poet.\n\n"
    "This is a **base model** (text completion, not chat/instruction-following). "
    "Type a Telugu prompt and the model will continue it.\n\n"
    "Developed by **[Dvitva AI](https://dvitva.ai)**."
)

demo = gr.Interface(
    fn=generate,
    inputs=[_prompt_box, *_sliders],
    outputs=gr.Textbox(label="Generated Text", lines=8),
    title="Pothana Base 300M",
    description=_DESCRIPTION,
    examples=examples,
    cache_examples=False,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    demo.launch()