Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
| import gc | |
class ModelManager:
    """Lazily load a causal-LM chat model and generate replies with it.

    The tokenizer and model are loaded on the first call to ``load_model``
    (directly, or indirectly through ``generate``) and cached on the instance
    for every later call.

    NOTE(review): the Korean log/UI strings in this class appear mis-encoded
    (UTF-8 bytes rendered as Greek/ISO-8859-7 characters); restore them from
    the original source file rather than retyping them.
    """

    def __init__(self, model_name="CohereForAI/c4ai-command-r-plus-4bit"):
        """Create an unloaded manager.

        Args:
            model_name: Hub model id (or local path) handed to
                ``from_pretrained``; defaults to the Command R+ 4-bit checkpoint.
        """
        self.model = None
        self.tokenizer = None
        self.model_name = model_name

    def load_model(self):
        """Load the tokenizer and model unless they are already loaded.

        Returns:
            True if the model is available (already or newly loaded), False if
            loading failed. Failures are printed, not raised, so the UI layer
            can report them as a chat reply.
        """
        if self.model is not None:
            return True
        try:
            print("λͺ¨λΈ λ‘λ© μ€... μκ°μ΄ 걸릴 μ μμ΅λλ€.")
            # Local import: only needed on the one-time load path.
            from transformers import BitsAndBytesConfig

            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                # Documented replacement for the deprecated `load_in_4bit=True`
                # keyword argument.
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            print(f"λͺ¨λΈ λ‘λ© μ€ν¨: {e}")
            return False
        # Publish both objects only once both loaded, so a failed load never
        # leaves a half-initialised manager (tokenizer set, model missing).
        self.tokenizer = tokenizer
        self.model = model
        print("λͺ¨λΈ λ‘λ© μλ£!")
        return True

    @staticmethod
    def _build_conversation(message, history):
        """Convert prior chat history plus the new message into chat-template messages.

        ``history`` items may be ``[user, assistant]`` pairs (a falsy assistant
        entry is skipped) or ready-made ``{"role", "content"}`` dicts; ``None``
        is treated as an empty history.
        """
        conversation = []
        for turn in history or []:
            if isinstance(turn, dict):
                conversation.append({"role": turn["role"], "content": turn["content"]})
                continue
            human, assistant = turn
            conversation.append({"role": "user", "content": human})
            if assistant:
                conversation.append({"role": "assistant", "content": assistant})
        conversation.append({"role": "user", "content": message})
        return conversation

    def generate(self, message, history, max_tokens=1000, temperature=0.7):
        """Generate the assistant's reply to ``message`` given ``history``.

        Args:
            message: The new user message.
            history: Prior turns; see ``_build_conversation`` for accepted forms.
            max_tokens: Maximum number of new tokens. Coerced to ``int`` because
                UI sliders may deliver floats.
            temperature: Sampling temperature. ``None`` or values <= 0 switch to
                greedy decoding instead of failing.

        Returns:
            The decoded reply (prompt tokens excluded), or an error-message
            string if loading or generation failed. Ordinary exceptions are
            converted to that string rather than raised.
        """
        if not self.load_model():
            return "λͺ¨λΈ λ‘λ©μ μ€ν¨νμ΅λλ€."
        try:
            conversation = self._build_conversation(message, history)
            # Tokenize with the model's chat template; add_generation_prompt
            # appends the prompt for the assistant's next turn.
            input_ids = self.tokenizer.apply_chat_template(
                conversation,
                return_tensors="pt",
                add_generation_prompt=True,
            )
            # Send inputs wherever the (device_map="auto") model reports its
            # device, instead of assuming a "cuda" device exists.
            input_ids = input_ids.to(self.model.device)

            gen_kwargs = {
                "max_new_tokens": int(max_tokens),
                # EOS doubles as the pad id (presumably to avoid generate()'s
                # missing-pad-token warning — kept from the original).
                "pad_token_id": self.tokenizer.eos_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
            }
            if temperature is not None and temperature > 0:
                gen_kwargs.update(do_sample=True, temperature=float(temperature))
            else:
                gen_kwargs["do_sample"] = False

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    # One unpadded sequence: every position is a real token.
                    attention_mask=torch.ones_like(input_ids),
                    **gen_kwargs,
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            return self.tokenizer.decode(
                outputs[0][input_ids.shape[-1]:],
                skip_special_tokens=True,
            )
        except Exception as e:
            return f"μμ± μ€ μ€λ₯ λ°μ: {str(e)}"
        finally:
            # Free cached GPU memory and Python garbage after every request.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()


# Process-wide model manager shared by all chat requests; the model itself is
# loaded lazily on the first generate() call.
model_manager = ModelManager()
def chat_fn(message, history, max_tokens, temperature):
    """Gradio handler: answer ``message`` and append the turn to the chat history.

    Args:
        message: Text from the input box; ``None`` or whitespace-only input is
            ignored and no model call is made.
        history: Current chatbot value as a list of ``[user, bot]`` pairs;
            ``None`` is treated as an empty history. The caller's list is not
            mutated.
        max_tokens: Value of the max-new-tokens slider.
        temperature: Value of the temperature slider.

    Returns:
        ``(updated_history, "")`` — the empty string clears the input box.
    """
    # Work on a copy so the incoming list is never mutated (and None -> []).
    history = list(history or [])
    if not message or not message.strip():
        return history, ""
    # Generate the bot reply from the prior turns, then record the new turn.
    response = model_manager.generate(message, history, max_tokens, temperature)
    history.append([message, response])
    return history, ""
# Header text shown above the chat window.
# NOTE(review): the Korean text here appears mis-encoded (UTF-8 rendered as
# ISO-8859-7); restore it from the original source.
_INTRO_MARKDOWN = """
# π€ Command R+ 4bit μ±ν λ΄
Cohereμ Command R+ 4bit μμν λͺ¨λΈκ³Ό λνν μ μμ΅λλ€.
β οΈ μ²« μ€ν μ λͺ¨λΈ λ‘λ©μ μκ°μ΄ 걸릴 μ μμ΅λλ€.
"""

# Gradio UI: chat window, input row, reset button and generation settings.
with gr.Blocks(title="Command R+ Chat") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    chatbot = gr.Chatbot(height=500, show_label=False, show_copy_button=True)

    with gr.Row():
        msg = gr.Textbox(
            label="λ©μμ§ μ λ ₯",
            placeholder="Command R+μκ² μ§λ¬ΈνμΈμ...",
            lines=2,
            scale=4,
        )
        submit = gr.Button("μ μ‘ π€", variant="primary", scale=1)

    with gr.Row():
        clear = gr.Button("λν μ΄κΈ°ν ποΈ")

    with gr.Accordion("κ³ κΈ μ€μ ", open=False):
        max_tokens = gr.Slider(
            label="μ΅λ ν ν° μ",
            minimum=100,
            maximum=2000,
            step=100,
            value=1000,
        )
        temperature = gr.Slider(
            label="Temperature (μ°½μμ±)",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.7,
        )

    # Pressing Enter in the textbox and clicking "send" run the same handler
    # with the same inputs/outputs, so wire both triggers in one place.
    chat_inputs = [msg, chatbot, max_tokens, temperature]
    chat_outputs = [chatbot, msg]
    for trigger in (msg.submit, submit.click):
        trigger(chat_fn, chat_inputs, chat_outputs)

    def _reset_chat():
        """Return an empty chat history and an empty input box."""
        return [], ""

    clear.click(_reset_chat, outputs=[chatbot, msg])
def _main():
    """Start the Gradio server for the chat demo with default launch options."""
    demo.launch()


if __name__ == "__main__":
    _main()