# Spaces: Runtime error
import json

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub identifiers: the base checkpoint and the adapter loaded on top of it.
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
ADAPTER = "AmadouDiarouga/telma-mistral-telecom-gn"

print("Chargement tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Reuse the EOS token as the padding token and pad on the right-hand side.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print("Chargement modele base...")
# NOTE(review): float32 weights cost 4 bytes per parameter, so a 7B-parameter
# checkpoint needs roughly 28 GB of RAM on the CPU host. Confirm the host has
# that much memory; if not, a half-precision dtype is worth evaluating.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32,
    device_map="cpu",
    low_cpu_mem_usage=True,
)

print("Chargement adapteur LoRA...")
# Wrap the base model with the adapter weights and switch to evaluation mode.
model = PeftModel.from_pretrained(base_model, ADAPTER)
model.eval()
print("Modele pret !")
# Instruction text placed at the start of the prompt for every language value
# other than "poular".
INSTRUCTION_FR = (
    "Tu es TELMA, un assistant vocal IA pour les telecoms en Guinee. "
    "Tu aides les clients d'Orange Guinee, MTN et Celcom. "
    "Reponds en francais simple, court et direct. "
    "Cite toujours le code USSD ou le numero quand c'est pertinent."
)

# Instruction text placed at the start of the prompt when the language value
# is "poular".
INSTRUCTION_PL = (
    "Be njahii e on ngondi e telecoms Guinee. "
    "Jaabii dow e Pulaar, yitere e laawol."
)
def repondre(question, langue="fr"):
    """Generate an answer to a question and return it as a JSON string.

    Args:
        question: The user's question. ``None``, an empty string, or a
            whitespace-only string is rejected with an error payload.
        langue: ``"poular"`` selects ``INSTRUCTION_PL``; any other value
            selects ``INSTRUCTION_FR``.

    Returns:
        A JSON-encoded string. For a rejected question it is
        ``{"error": "question vide"}``. Otherwise it holds the keys
        ``question`` (as received), ``reponse`` (the generated text),
        ``langue`` and ``status`` (always ``"ok"``).
    """
    # Treat None like an empty question instead of failing on .strip().
    texte = (question or "").strip()
    if not texte:
        return json.dumps({"error": "question vide"})

    instruction = INSTRUCTION_PL if langue == "poular" else INSTRUCTION_FR
    # NOTE(review): the prompt spells out "<s>" while the tokenizer call uses
    # its default special-token handling; if that default also prepends a BOS
    # token, the sequence starts with two of them. Confirm against the format
    # the adapter was trained on before changing either side.
    prompt = f"<s>[INST] {instruction}\n\n{texte} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.3,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    # Decode only the tokens produced after the prompt. Splitting the decoded
    # text on "[/INST]" would cut the answer short whenever the generated text
    # itself contains that marker.
    prompt_len = inputs["input_ids"].shape[-1]
    reponse = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    return json.dumps(
        {
            "question": question,
            "reponse": reponse,
            "langue": langue,
            "status": "ok",
        },
        ensure_ascii=False,
    )
# Web UI: a question box plus a language selector feeding `repondre`, with the
# JSON string it returns shown in a single text box.
# NOTE(review): Gradio can pre-compute the `examples` rows at startup when
# example caching is enabled (reportedly the default on Hugging Face Spaces).
# Each row would run a full CPU generation; confirm and consider passing
# cache_examples=False if startup is slow.
demo = gr.Interface(
    fn=repondre,
    inputs=[
        gr.Textbox(
            label="Question",
            placeholder="Comment verifier mon solde Orange ?",
        ),
        gr.Dropdown(choices=["fr", "poular"], value="fr", label="Langue"),
    ],
    outputs=gr.Textbox(label="Reponse JSON"),
    title="TELMA - Assistant Telecoms Guinee",
    description="API vocale pour Orange, MTN et Celcom en Guinee",
    examples=[
        ["Comment verifier mon solde Orange ?", "fr"],
        ["Je veux internet MTN pour une semaine", "fr"],
        ["Hol balansi Orange am ?", "poular"],
    ],
)

# Listen on all interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)