| |
| import gradio as gr |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from transformers import __version__ as transformers_version |
|
|
MODEL_NAME = "kvn420/Tenro_V4.1"
RECOMMENDED_TRANSFORMERS_VERSION = "4.37.0"


def _version_key(version: str) -> tuple[int, ...]:
    """Parse a dotted version string (e.g. "4.37.0") into an int tuple.

    Non-numeric trailing segments ("4.38.0.dev0") stop the parse, so the
    comparable numeric prefix is what gets compared.
    """
    parts: list[int] = []
    for chunk in version.split("."):
        if not chunk.isdigit():
            break
        parts.append(int(chunk))
    return tuple(parts)


print(f"Version de Transformers : {transformers_version}")
# BUG FIX: the original compared version *strings* lexicographically, so e.g.
# "4.9.0" compared as >= "4.37.0" ('9' > '3') and the warning never fired.
# Compare numeric components instead.
if _version_key(transformers_version) < _version_key(RECOMMENDED_TRANSFORMERS_VERSION):
    print(f"Attention : Version Transformers ({transformers_version}) < Recommandée ({RECOMMENDED_TRANSFORMERS_VERSION}). Mettez à jour.")
|
|
| |
# Load the tokenizer and model eagerly at import time so the first chat
# request does not pay the load cost.
#
# Pre-bind both names to None so that, if loading fails, the names still
# exist and chat_interaction's `model is None / tokenizer is None` guard
# returns its error message instead of raising NameError.
tokenizer = None
model = None

print(f"Chargement du tokenizer pour : {MODEL_NAME}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,  # repo may ship custom tokenizer code
    )
    print("Tokenizer chargé.")

    print(f"Chargement du modèle : {MODEL_NAME}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,  # half the memory of fp32; assumes bf16-capable hardware — TODO confirm
        trust_remote_code=True,
        device_map="auto",  # let accelerate place weights on available device(s)
    )
    print(f"Modèle chargé sur {model.device}.")

    # Some tokenizers ship without a pad token; generate() needs one for
    # padding, so fall back to the EOS token id.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        print(f"tokenizer.pad_token_id défini sur eos_token_id: {tokenizer.eos_token_id}")

except Exception as e:
    print(f"Erreur critique lors du chargement du modèle/tokenizer : {e}")
    # Chain the original exception (`from e`) so Space logs show the root cause.
    raise gr.Error(f"Impossible de charger le modèle ou le tokenizer: {e}. Vérifiez les logs du Space.") from e
| |
|
|
def chat_interaction(user_input, history):
    """Generate one assistant reply for the Gradio chat interface.

    Parameters
    ----------
    user_input : str
        The new message typed by the user.
    history : list
        Previous turns. Both Gradio history formats are accepted:
        - "tuples" format: list of ``[user_message, assistant_message]`` pairs;
        - "messages" format: list of ``{"role": ..., "content": ...}`` dicts.

    Returns
    -------
    str
        The model's reply, or an error message (in French, matching the UI)
        if the model is unavailable or generation fails.
    """
    if model is None or tokenizer is None:
        return "Erreur: Modèle ou tokenizer non initialisé."

    # Rebuild the whole conversation as role/content dicts, the shape
    # expected by tokenizer.apply_chat_template.
    messages_for_template = []
    for turn in history:
        if isinstance(turn, dict):
            # Already in "messages" format (newer Gradio versions).
            messages_for_template.append(
                {"role": turn["role"], "content": turn["content"]}
            )
        else:
            # "tuples" format: [user_message, assistant_message].
            user_msg, assistant_msg = turn
            messages_for_template.append({"role": "user", "content": user_msg})
            messages_for_template.append({"role": "assistant", "content": assistant_msg})
    messages_for_template.append({"role": "user", "content": user_input})

    try:
        prompt_tokenized = tokenizer.apply_chat_template(
            messages_for_template,
            tokenize=True,
            add_generation_prompt=True,  # append the assistant prefix so the model answers
            return_tensors="pt",
        ).to(model.device)

        # Inference only: disable autograd so no gradient graph is built
        # (saves memory and time during generation).
        with torch.no_grad():
            outputs = model.generate(
                prompt_tokenized,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens, skipping the prompt echo.
        response_text = tokenizer.decode(
            outputs[0][prompt_tokenized.shape[-1]:], skip_special_tokens=True
        )

        # Strip chat-template artifacts that can survive decoding.
        response_text = response_text.replace("<|im_end|>", "").strip()
        if response_text.startswith("assistant\n"):
            response_text = response_text.split("assistant\n", 1)[-1].strip()

        return response_text

    except Exception as e:
        print(f"Erreur pendant la génération : {e}")
        return f"Désolé, une erreur est survenue : {e}"
|
|
| |
| |
# Build the chat UI around chat_interaction.
# NOTE(review): the retry_btn/undo_btn/clear_btn/submit_btn keyword arguments
# were accepted by Gradio 3.x/4.x but removed in later major versions —
# confirm the Gradio version pinned for this Space still supports them.
iface = gr.ChatInterface(
    fn=chat_interaction,
    title=f"Chat avec {MODEL_NAME}",
    description=f"Interface de démonstration pour le modèle {MODEL_NAME}. Le modèle est hébergé sur Hugging Face et chargé ici.",
    chatbot=gr.Chatbot(height=600),
    textbox=gr.Textbox(placeholder="Posez votre question ici...", container=False, scale=7),
    retry_btn="Réessayer",
    undo_btn="Annuler",
    clear_btn="Effacer la conversation",
    submit_btn="Envoyer"
)
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Start the Gradio server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    iface.launch()
|
|
|
|