# Hugging Face Space app — source recovered from a "Runtime error" page dump.
import gradio as gr
import torch
import os
import sys
import time
import json
from typing import List
from transformers import (
    LlamaTokenizer,
    LlamaForCausalLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaConfig,
)
from peft import PeftModel
from accelerate import disk_offload

# Load the Llama-2-7B base model (with an Amharic tokenizer vocabulary)
# in 8-bit quantized form. NOTE(review): load_in_8bit requires the
# bitsandbytes package and a CUDA device — on a CPU-only Space this call
# raises at startup, which matches the "Runtime error" banner this file
# was recovered from. TODO: confirm hardware before deploying.
model = AutoModelForCausalLM.from_pretrained(
    "Johntad110/llama-2-7b-amharic-tokenizer",
    return_dict=True,
    load_in_8bit=True,
    device_map="auto",
    low_cpu_mem_usage=True,
    attn_implementation="sdpa",
)
tokenizer = LlamaTokenizer.from_pretrained(
    "Johntad110/llama-2-7b-amharic-tokenizer"
)

# The Amharic tokenizer's vocabulary size can differ from the base
# model's embedding table; resize so every token id indexes in range.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) != embedding_size:
    model.resize_token_embeddings(len(tokenizer))

# Stack the Amharic LoRA adapter on top of the (possibly resized) base.
model = PeftModel.from_pretrained(model, "Johntad110/llama-2-amharic-peft")
model.eval()  # inference only: disables dropout / training-mode layers
def generate_text(
    prompt: str,
    max_new_tokens: int = None,
    seed: int = 42,
    do_sample: bool = True,
    min_length: int = None,
    use_cache: bool = True,
    top_p: float = 1.0,
    temperature: float = 1.0,
    top_k: int = 1,
    repetition_penalty: float = 1.0,
    length_penalty: int = 1,
):
    """Generate a continuation of ``prompt`` with the Amharic Llama-2 model.

    Args:
        prompt: Input text to continue.
        max_new_tokens: Cap on newly generated tokens (None = library default).
        seed: RNG seed so sampling is reproducible across calls.
        do_sample: Sample from the distribution instead of greedy decoding.
        min_length: Minimum total sequence length (None = library default).
        use_cache: Reuse past key/values during decoding.
        top_p, temperature, top_k, repetition_penalty, length_penalty:
            Standard ``model.generate`` decoding knobs, forwarded verbatim.

    Returns:
        The decoded output string (special tokens stripped).
    """
    # Seed the CPU RNG always, and the CUDA RNG only when CUDA exists —
    # the original unconditional torch.cuda.manual_seed call is a crash
    # risk on CPU-only hosts.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    batch = tokenizer(prompt, return_tensors="pt")
    # Move inputs to wherever accelerate placed the model. This replaces
    # the hard-coded .to("cuda"), which raised on machines without a GPU.
    device = model.device
    batch = {k: v.to(device) for k, v in batch.items()}

    # No gradients needed for inference; saves memory and time.
    with torch.no_grad():
        outputs = model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            min_length=min_length,
            use_cache=use_cache,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
        )
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return output_text
# Minimal UI: a single prompt box feeding generate_text; every other
# decoding parameter keeps its default value.
prompt_box = gr.Textbox(label="Prompt")

interface = gr.Interface(
    fn=generate_text,
    inputs=[prompt_box],
    outputs="text",
)

# debug=True surfaces tracebacks in the console while the app runs.
interface.launch(debug=True)