# Hugging Face Space (scraped status header: Sleeping)
import json
import torch
import gradio as gr
from transformers import AutoModelForCausalLM
from huggingface_hub import hf_hub_download

# HF repo containing the model weights and its character-level vocab files.
repo_id = "theguywhosucks/mochaV1-base"

# Download the tokenizer lookup tables from the Hub.
itos_file = hf_hub_download(repo_id, "itos.json")
stoi_file = hf_hub_download(repo_id, "stoi.json")

# Explicit encoding: vocab JSON is UTF-8; the platform default may not be.
with open(stoi_file, encoding="utf-8") as f:
    stoi = json.load(f)
with open(itos_file, encoding="utf-8") as f:
    itos = json.load(f)

# JSON object keys are always strings, so an index->char map may arrive as a
# dict {"0": ..., "1": ...}; convert it to a positional list for int lookup.
if isinstance(itos, dict):
    itos = [itos[str(i)] for i in range(len(itos))]
# Tokenizer
class SimpleTokenizer:
    """Character-level tokenizer backed by stoi (char->id) / itos (id->char) tables."""

    def __init__(self, stoi, itos):
        self.stoi = stoi
        self.itos = itos
        # Prefer an explicit <unk> entry; otherwise fall back to the first vocab char.
        self.unk_token = "<unk>" if "<unk>" in stoi else itos[0]

    def encode(self, text):
        """Map each character of *text* to its id; unknown chars get the unk id."""
        unk_id = self.stoi.get(self.unk_token, 0)
        return [self.stoi.get(ch, unk_id) for ch in text]

    def decode(self, ids):
        """Map ids back to characters; out-of-range ids render as the unk token."""
        vocab_size = len(self.itos)
        chars = []
        for token_id in ids:
            if token_id < vocab_size:
                chars.append(self.itos[token_id])
            else:
                chars.append(self.unk_token)
        return "".join(chars)
# Tokenizer instance shared by the inference function below.
tokenizer = SimpleTokenizer(stoi, itos)

# Select the compute device: GPU when available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the causal LM from the Hub, move it to the device, and freeze it
# for inference (eval mode disables dropout etc.).
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float32,
    trust_remote_code=True,
).to(device)
model.eval()
# Gradio function
def complete_sentence(prompt, max_new_tokens=50, temperature=0.7):
    """Generate a continuation of *prompt* with the loaded model.

    Args:
        prompt: Seed text, encoded character-by-character.
        max_new_tokens: Upper bound on newly sampled tokens.
        temperature: Sampling temperature (higher = more random).

    Returns:
        The prompt followed by the sampled continuation, decoded to text.
    """
    # Guard: generate() fails on a zero-length input sequence.
    if not prompt:
        return ""
    input_ids = torch.tensor(
        [tokenizer.encode(prompt)], dtype=torch.long, device=device
    )
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            # Gradio sliders deliver floats; generate() expects an int count.
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
        )
    # generate() returns the prompt ids plus the new tokens; decode all of it.
    return tokenizer.decode(outputs[0].tolist())
# Launch Gradio app: build the interface first, then start serving it.
demo = gr.Interface(
    fn=complete_sentence,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(10, 200, value=50, step=10, label="Max new tokens"),
        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Completed Text"),
    title="Mocha Sentence Completion",
    description="Enter a prompt and get AI completions from your model.",
)
demo.launch()