Spaces:
Sleeping
Sleeping
| import json | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM | |
| from huggingface_hub import hf_hub_download | |
# --------------------
# Model + Tokenizer
# --------------------
repo_id = "theguywhosucks/mochaV2"

# Fetch the character-vocabulary tables that ship alongside the weights.
itos_file = hf_hub_download(repo_id, "itos.json")
stoi_file = hf_hub_download(repo_id, "stoi.json")

with open(stoi_file) as vocab_in, open(itos_file) as inverse_in:
    stoi = json.load(vocab_in)
    itos = json.load(inverse_in)

# Some exports store itos as {"0": "a", "1": "b", ...}; normalise that
# shape into a plain index-addressable list.
if isinstance(itos, dict):
    itos = [itos[str(i)] for i in range(len(itos))]
class SimpleTokenizer:
    """Minimal character-level tokenizer backed by stoi/itos lookup tables.

    Args:
        stoi: Mapping from single characters to integer token ids.
        itos: List mapping token ids back to characters (index = id).
    """

    def __init__(self, stoi: dict, itos: list):
        self.stoi = stoi
        self.itos = itos
        # Fall back to the first vocabulary entry when no explicit <unk>
        # token exists in the vocab.
        self.unk_token = "<unk>" if "<unk>" in stoi else itos[0]

    def encode(self, text: str) -> list:
        """Map each character of *text* to its id; unknown chars map to unk."""
        # Hoist the unknown-token id out of the per-character loop — the
        # original resolved it with a dict lookup for every character.
        unk_id = self.stoi.get(self.unk_token, 0)
        return [self.stoi.get(ch, unk_id) for ch in text]

    def decode(self, ids: list) -> str:
        """Map token ids back to characters; out-of-range ids become unk."""
        vocab_size = len(self.itos)
        return "".join(
            self.itos[i] if i < vocab_size else self.unk_token for i in ids
        )
# Build the tokenizer from the freshly loaded vocab tables.
tokenizer = SimpleTokenizer(stoi, itos)

# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): trust_remote_code executes model code pulled from the hub —
# presumably mochaV2 ships a custom architecture; verify the repo is trusted.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float32,
    trust_remote_code=True,
).to(device)

# Inference only: disable dropout / training-mode layers.
model.eval()
# --------------------
# Inference Function
# --------------------
def complete_sentence(prompt, max_new_tokens=50, temperature=0.7):
    """Generate a sampled continuation of *prompt* with the Mocha model.

    Args:
        prompt: Input text, encoded character-by-character by the tokenizer.
        max_new_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature (the UI slider keeps it >= 0.1).

    Returns:
        The decoded output sequence — prompt characters included, since
        generate() returns prompt ids followed by the new tokens.
    """
    # dtype must be explicit: for an empty prompt, torch.tensor([[]])
    # would infer float32 and crash the embedding lookup in generate().
    input_ids = torch.tensor(
        [tokenizer.encode(prompt)], dtype=torch.long
    ).to(device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            # Gradio sliders may deliver floats; coerce to the types
            # generate() expects.
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
        )
    return tokenizer.decode(outputs[0].tolist())
# --------------------
# UI / Gradio Blocks
# --------------------
changelog_md = """
### 📜 Changelog
**v1.2.0**
- **NEW** mochaV2 model!!
- Added clean tab-based UI
- Added this Changelog section
- Improved layout and descriptions
**v1.1.0**
- Added sliders for max tokens & temperature
**v1.0.0**
- Initial release with basic completion
"""

about_md = """
### ℹ️ About
MochaV2 is a sentence completion model trained from scratch for text generation tasks.
Built with ❤️ by the mocha team.
"""

# Three-tab layout: the completion form, then two static info tabs.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# ☕ Mocha AI Sentence Completion")
    gr.Markdown("Enter a prompt and let the model generate continuations.")

    with gr.Tab("Completion"):
        with gr.Row():
            # Left column: user inputs + generate button.
            with gr.Column(scale=1):
                prompt_box = gr.Textbox(
                    label="Prompt", placeholder="Type your text here..."
                )
                max_tokens_slider = gr.Slider(
                    10, 200, value=50, step=10, label="Max new tokens"
                )
                temperature_slider = gr.Slider(
                    0.1, 2.0, value=0.7, step=0.1, label="Temperature"
                )
                generate_btn = gr.Button("Generate", variant="primary")
            # Right column: model output.
            with gr.Column(scale=1):
                completion_box = gr.Textbox(label="Completed Text")

        generate_btn.click(
            fn=complete_sentence,
            inputs=[prompt_box, max_tokens_slider, temperature_slider],
            outputs=completion_box,
        )

    with gr.Tab("Changelog"):
        gr.Markdown(changelog_md)

    with gr.Tab("About"):
        gr.Markdown(about_md)

    gr.Markdown("---")
    gr.Markdown("⚡ Built by the mocha team")

demo.launch()