# HuggingFace Spaces app (page-scrape artifacts "Spaces:" / "Paused" removed).
from threading import Thread

from transformers import TextStreamer, TextIteratorStreamer
from unsloth import FastLanguageModel
import torch
import gradio as gr

# --- Model configuration -----------------------------------------------------
max_seq_length = 2048  # Choose any! We auto support RoPE Scaling internally!
dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.
model_name = "Danielrahmai1991/llama32_ganjoor_adapt_basic_model_16bit_v1"

# Download/load the model and tokenizer (network + GPU side effects at import time).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    trust_remote_code=True,
    # token = "hf_...",  # use one if using gated models like meta-llama/Llama-2-7b-hf
)
# Switch the unsloth model into its optimized inference mode.
FastLanguageModel.for_inference(model)
print("model loaded")
import re

# deep_translator: translation + language-detection helpers used by llm_run.
# NOTE(review): only GoogleTranslator and single_detection are used below;
# the other translators are imported but unused in this file's visible code.
from deep_translator import (GoogleTranslator,
                             PonsTranslator,
                             LingueeTranslator,
                             MyMemoryTranslator,
                             YandexTranslator,
                             DeeplTranslator,
                             QcriTranslator,
                             single_detection,
                             batch_detection)
# Yandex-based spell checker used to correct English text.
from pyaspeller import YandexSpeller
def error_correct_pyspeller(sample_text):
    """Spell-correct *sample_text* via the Yandex speller and return the fixed string."""
    return YandexSpeller().spelled(sample_text)
def postprocerssing(inp_text: str):
    """Post-process an LLM response.

    Strips HTML-like tags, keeps only the text before the first '##'
    marker, then spell-corrects the remainder.
    """
    without_tags = re.sub('<[^>]+>', '', inp_text)
    head, _sep, _rest = without_tags.partition('##')
    return error_correct_pyspeller(head)
def llm_run(prompt, max_length, top_p, temprature, top_k, messages):
    """Streaming chat handler for the Gradio UI (generator).

    Detects the prompt's language, translates it to English for the model,
    streams the generated reply, and yields the accumulated reply translated
    back to the detected language after each chunk. `messages` is the mutable
    chat history (list of role/content dicts) held in gr.State and is
    appended to in place.
    """
    print("prompt, max_length, top_p, temprature, top_k, messages", prompt, max_length, top_p, temprature, top_k, messages)
    # SECURITY NOTE(review): the detection API key is hard-coded — move it to an
    # environment variable / Space secret.
    lang = single_detection(prompt, api_key='4ab77f25578d450f0902fb42c66d5e11')
    if lang == 'en':
        # Spell-correct English prompts before feeding them to the model.
        prompt = error_correct_pyspeller(prompt)
    # Normalize every prompt to English for the model.
    en_translated = GoogleTranslator(source='auto', target='en').translate(prompt)
    messages.append({"role": "user", "content": en_translated})
    # messages.append({"role": "user", "content": prompt})
    print("messages", messages)
    # Serialize the running chat history into model input ids.
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True,
        return_tensors = "pt",
    )
    # Streamer yields decoded text chunks as generation progresses.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        max_length=int(max_length),top_p=float(top_p), do_sample=True,
        top_k=int(top_k), streamer=streamer, temperature=float(temprature), repetition_penalty=1.2
    )
    # Run generation on a background thread so this generator can consume the
    # streamer concurrently.
    t = Thread(target=model.generate, args=(input_ids,), kwargs=generate_kwargs)
    t.start()
    generated_text=[]
    for text in streamer:
        generated_text.append(text)
        print('generated_text: ', generated_text)
        # yield "".join(generated_text)
        # Re-translate the full accumulated text back to the detected language
        # on every chunk (a fresh translation request per chunk).
        yield GoogleTranslator(source='auto', target=lang).translate("".join(generated_text))
    # Record the (English) assistant reply in the shared history.
    messages.append({"role": "assistant", "content": "".join(generated_text)})
def clear_memory(messages):
    """Empty the shared chat history in place and return a status string for the UI."""
    del messages[:]
    return "Memory cleaned."
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.orange, secondary_hue=gr.themes.colors.pink)) as demo:
    # Per-session chat history shared between llm_run and clear_memory.
    stored_message = gr.State([])
    with gr.Row():
        with gr.Column(scale=2):
            prompt_text = gr.Textbox(lines=7, label="Prompt", scale=2)
            with gr.Row():
                btn1 = gr.Button("Submit", scale=1)
                btn2 = gr.Button("Clear", scale=1)
                btn3 = gr.Button("Clean Memory", scale=2)
        with gr.Column(scale=2):
            out_text = gr.Text(lines=15, label="Output", scale=2)
    # NOTE(review): the generation-parameter components below are constructed
    # inline inside the Blocks context, so they render at this point in the
    # layout — confirm this placement is intended.
    btn1.click(fn=llm_run, inputs=[
        prompt_text,
        gr.Textbox(label="Max-Lenth generation", value=500),
        gr.Slider(0.0, 1.0, label="Top-P value", value=0.90),
        gr.Slider(0.0, 1.0, label="Temprature value", value=0.65),
        gr.Textbox(label="Top-K", value=50,),
        stored_message
    ], outputs=out_text)
    # Clear both the prompt box and the output box.
    btn2.click(lambda: [None, None], outputs=[prompt_text, out_text])
    # Wipe the stored chat history; status message goes to the output box.
    btn3.click(fn=clear_memory, inputs=[stored_message], outputs=[out_text])
# demo = gr.Interface(fn=llm_run, inputs=["text"], outputs="text")
demo.launch(debug=True, share=True)