Spaces:
Runtime error
Runtime error
| import os | |
| import time | |
| import logging | |
| from flask import Flask, request, jsonify | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
| from flores200_codes import flores_codes | |
# Configure the root logger first so that everything below, including model
# loading, is emitted at DEBUG verbosity.
logging.basicConfig(level=logging.DEBUG)

# Flask application object for this module.
app = Flask(__name__)
def load_models():
    """Load every configured checkpoint together with its tokenizer.

    Returns:
        dict: Two entries per configured call name, keyed
        ``"<call_name>_model"`` and ``"<call_name>_tokenizer"``.
    """
    checkpoints = {"nllb-distilled-600M": "facebook/nllb-200-distilled-600M"}
    loaded = {}
    for alias, hub_id in checkpoints.items():
        logging.info(f"\tLoading model: {alias}")
        # Dict displays evaluate in order: the model is loaded before the
        # tokenizer.
        loaded.update(
            {
                alias + "_model": AutoModelForSeq2SeqLM.from_pretrained(hub_id),
                alias + "_tokenizer": AutoTokenizer.from_pretrained(hub_id),
            }
        )
    return loaded
# Loaded once at import time so every request reuses the same model and
# tokenizer. (A `global` statement at module scope has no effect, so a plain
# assignment is all that is needed here.)
model_dict = load_models()
# NOTE(review): the handler was not registered with the app at all, so no
# endpoint was reachable. The "/translate" path is a choice made here --
# confirm it matches what clients call.
@app.route("/translate", methods=["POST"])
def translate_text():
    """Translate the text supplied in a JSON request body.

    The body must be a JSON object with three non-empty string fields:
    ``source`` and ``target`` (both looked up in ``flores_codes``) and
    ``text`` (the text to translate).

    Returns:
        On success, a JSON response with ``inference_time`` (seconds),
        ``source``, ``target`` and ``result``. Otherwise a JSON object with
        an ``error`` message and HTTP status 400 when the body is not a JSON
        object, a field is missing or not a string, or a language is not
        present in ``flores_codes``.
    """
    # silent=True returns None instead of raising on a missing or malformed
    # JSON body. Anything that is not an object (list, string, null) is
    # rejected here so the .get() calls below cannot raise AttributeError.
    data = request.get_json(silent=True)
    logging.debug("Received data: %s", data)
    if not isinstance(data, dict):
        logging.error("Request body is not a JSON object")
        return jsonify({"error": "Request body must be a JSON object"}), 400

    source_lang = data.get("source")
    target_lang = data.get("target")
    input_text = data.get("text")
    fields = (source_lang, target_lang, input_text)

    if not all(fields):
        logging.error("Missing fields in the request")
        return jsonify({"error": "source, target, and text fields are required"}), 400

    # Non-string values would otherwise fail later: an unhashable value makes
    # the flores_codes lookup raise TypeError.
    if not all(isinstance(field, str) for field in fields):
        logging.error("Non-string field in the request")
        return jsonify({"error": "source, target, and text must be strings"}), 400

    source = flores_codes.get(source_lang)
    target = flores_codes.get(target_lang)
    if not source or not target:
        logging.error("Invalid source or target language code")
        return jsonify({"error": "Invalid source or target language code"}), 400

    model_name = "nllb-distilled-600M"
    model = model_dict[model_name + "_model"]
    tokenizer = model_dict[model_name + "_tokenizer"]

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=source,
        tgt_lang=target,
    )
    output = translator(input_text, max_length=400)
    end_time = time.perf_counter()

    result = {
        "inference_time": end_time - start_time,
        "source": source_lang,
        "target": target_lang,
        "result": output[0]["translation_text"],
    }
    logging.debug("Translation result: %s", result)
    return jsonify(result)
if __name__ == "__main__":
    # Debug mode enables the interactive Werkzeug debugger, which lets anyone
    # who can reach the server execute code, and this server binds to all
    # interfaces. It also turns on the reloader, which re-imports this module
    # in a child process and therefore loads the model a second time. Keep it
    # off unless explicitly requested through FLASK_DEBUG.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true"}
    # PORT overrides the listening port; the original default of 5000 is kept.
    port = int(os.environ.get("PORT", "5000"))
    app.run(host="0.0.0.0", port=port, debug=debug_enabled)