# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" below were a Hugging Face
# Spaces status banner that leaked into this file during export; kept here as a
# comment because they are not valid Python.
# Spaces: Sleeping Sleeping
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import json | |
| import os | |
| import requests | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
| from dotenv import load_dotenv | |
# --- Load environment variables ---
# Pull variables from a local .env file (if present) into os.environ.
load_dotenv()

# --- Configuration ---
# Hugging Face Hub IDs for the fine-tuned MarianMT translation models.
HF_EN_FR_REPO_ID = "cgpcorpbot/cgp_model_en-fr"
HF_FR_EN_REPO_ID = "cgpcorpbot/cgp_model_fr-en"

# Gemini API configuration. An empty-string default lets the app start without
# a key; the helpers below degrade gracefully when it is missing.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
# generateContent endpoint for the Gemini 2.0 Flash model.
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
# --- Load MarianMT models and tokenizers using pipeline ---
# These are loaded once at import time (i.e. when the app starts); a failure
# here is fatal, since the service cannot translate without them.
try:
    # CPU-only inference; no GPU is assumed in the deployment environment.
    device = torch.device("cpu")

    print("Loading EN->FR model...")
    tokenizer_en_fr = AutoTokenizer.from_pretrained(HF_EN_FR_REPO_ID)
    model_en_fr = AutoModelForSeq2SeqLM.from_pretrained(HF_EN_FR_REPO_ID).to(device)
    # Wrap model+tokenizer in a pipeline for simpler translation calls.
    translator_en_fr = pipeline("translation", model=model_en_fr, tokenizer=tokenizer_en_fr, device=device)
    print(f"✅ MarianMT EN-FR Model loaded from Hugging Face Hub: {HF_EN_FR_REPO_ID} and moved to {device}")

    print("Loading FR->EN model...")
    tokenizer_fr_en = AutoTokenizer.from_pretrained(HF_FR_EN_REPO_ID)
    model_fr_en = AutoModelForSeq2SeqLM.from_pretrained(HF_FR_EN_REPO_ID).to(device)
    # Wrap model+tokenizer in a pipeline for simpler translation calls.
    translator_fr_en = pipeline("translation", model=model_fr_en, tokenizer=tokenizer_fr_en, device=device)
    print(f"✅ MarianMT FR-EN Model loaded from Hugging Face Hub: {HF_FR_EN_REPO_ID} and moved to {device}")
except Exception as e:
    print(f"❌ Failed to load MarianMT models from Hugging Face Hub: {e}")
    # Chain the original exception so the download/load failure is preserved.
    raise RuntimeError(f"Failed to load translation models: {e}") from e
# --- Language Detection (Simplified) ---
def detect_language(text: str) -> str:
    """Heuristically classify *text* as ``"french"`` or ``"english"``.

    Counts how many distinct common French function words appear among the
    whitespace-separated tokens of the lowercased text; more than two hits
    classifies the text as French. Anything else (including empty input)
    defaults to English.
    """
    french_keywords = frozenset([
        "le", "la", "les", "un", "une", "des", "est", "sont", "je", "tu",
        "il", "elle", "nous", "vous", "ils", "elles", "pas", "de", "du",
        "et", "à", "en", "que", "qui", "quoi", "comment", "où", "quand",
    ])
    # Split once and use set intersection; the original re-split the text
    # for every keyword (O(keywords * tokens)).
    tokens = set(text.lower().split())
    if len(french_keywords & tokens) > 2:
        return "french"
    return "english"
# --- MarianMT Translation with Gemini Fallback ---
async def translate_text_with_fallback(text: str, direction: str) -> str:
    """
    Translate *text* using the custom MarianMT pipelines, falling back to the
    Gemini API if the local model fails.

    Args:
        text: Source text; blank/whitespace-only input short-circuits to "".
        direction: Either "en-fr" or "fr-en".

    Returns:
        The translated text, an error-sentinel string ("Invalid translation
        direction.", or one of the Gemini-fallback error strings), or "" for
        blank input.

    BUG FIX: the original was a *sync* function that returned the un-awaited
    coroutine produced by ``gemini_translate_for_translation`` — callers that
    ``await`` this function would fail, and the fallback never ran. It is now
    async and awaits the fallback.
    """
    if not text.strip():
        return ""
    try:
        if direction == "en-fr":
            translated_result = translator_en_fr(text, max_length=128)
        elif direction == "fr-en":
            translated_result = translator_fr_en(text, max_length=128)
        else:
            return "Invalid translation direction."
        return translated_result[0]['translation_text'].strip()
    except Exception as e:
        print(f"⚠️ Local MarianMT model failed for {direction}, falling back to Gemini for translation: {e}")
        # Fallback to Gemini for translation if MarianMT fails.
        return await gemini_translate_for_translation(text, direction)
async def gemini_translate_for_translation(text: str, direction: str) -> str:
    """
    Use the Gemini API to translate *text* when the MarianMT models fail.

    This is a dedicated translation fallback, separate from the conversational
    Gemini call. Returns the translated text, or one of the sentinel error
    strings checked by get_multilingual_chatbot_response().

    NOTE(review): requests.post is a blocking call inside an async function and
    will stall the event loop for the duration of the request — consider
    run_in_executor or an async HTTP client.
    """
    if not GEMINI_API_KEY:
        print("❌ Gemini API Key is missing for translation fallback.")
        return "API key missing for translation."

    target_lang = "French" if direction == "en-fr" else "English"
    # Prompt Gemini to perform translation.
    prompt = f"Translate the following text to {target_lang}: \"{text}\""
    chat_history = [{"role": "user", "parts": [{"text": prompt}]}]
    payload = {"contents": chat_history}
    headers = {'Content-Type': 'application/json'}
    api_url_with_key = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}"
    try:
        # FIX: added a timeout — without one, requests.post can hang forever
        # and freeze the (single-threaded) event loop.
        response = requests.post(api_url_with_key, headers=headers, data=json.dumps(payload), timeout=30)
        response.raise_for_status()
        result = response.json()
        return result.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "Translation failed via Gemini.")
    except Exception as e:
        print(f"Gemini API translation fallback error: {e}")
        return "Gemini API translation error."
# --- Main Gemini API Call for Conversational Response ---
async def call_gemini_api_for_response(prompt: str) -> str:
    """
    Call the Gemini API to get a conversational response in English.

    This is the primary function for generating chatbot responses. Gemini is
    always prompted in English so behavior is consistent regardless of the
    user's language; translation happens around this call.

    Returns the model's text, or one of the sentinel error strings checked by
    get_multilingual_chatbot_response().

    NOTE(review): requests.post is a blocking call inside an async function —
    same event-loop concern as in gemini_translate_for_translation().
    """
    if not GEMINI_API_KEY:
        print("❌ Gemini API Key is missing for main response.")
        return "API key missing."

    gemini_prompt = f"Answer the following question in English: {prompt}"
    chat_history = [{"role": "user", "parts": [{"text": gemini_prompt}]}]
    payload = {"contents": chat_history}
    headers = {'Content-Type': 'application/json'}
    api_url_with_key = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}"
    try:
        # FIX: added a timeout — without one, requests.post can hang forever
        # and freeze the (single-threaded) event loop.
        response = requests.post(api_url_with_key, headers=headers, data=json.dumps(payload), timeout=30)
        response.raise_for_status()
        result = response.json()
        return result.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "No response from Gemini.")
    except Exception as e:
        print(f"Gemini API error for main response: {e}")
        return "Gemini API error for main response."
# --- Main Chatbot Response Logic ---
async def get_multilingual_chatbot_response(user_input: str) -> str:
    """
    Generate a chatbot response: detect the input language, translate French
    input to English, query Gemini (always in English), and translate the
    answer back to French when appropriate.

    Falls back to the original input / English response whenever a translation
    step returns an empty string or an error-sentinel string.
    """
    # Sentinel strings returned by translate_text_with_fallback() /
    # gemini_translate_for_translation() and call_gemini_api_for_response();
    # kept in one place so the checks below stay consistent with them.
    translation_errors = ("API key missing for translation.", "Gemini API translation error.")
    gemini_errors = ("API key missing for main response.", "Gemini API error for main response.", "No response from Gemini.")

    detected_lang = detect_language(user_input)
    print(f"Detected language: {detected_lang.upper()}")

    english_query = user_input
    if detected_lang == "french":
        print("Translating French query to English...")
        # Use the translation function with Gemini fallback.
        english_query = await translate_text_with_fallback(user_input, "fr-en")
        print(f"Translated query (EN): {english_query}")
        if not english_query.strip() or english_query in translation_errors:
            # If translation fails or is empty, use the original input for Gemini.
            english_query = user_input
            print("French to English translation resulted in empty string or error, using original input for Gemini.")

    # Get the conversational response from Gemini (always in English).
    gemini_response_en = await call_gemini_api_for_response(english_query)
    print(f"Gemini response (EN): {gemini_response_en}")

    final_response = gemini_response_en
    # Only translate back if the original input was French AND Gemini
    # provided a valid (non-error) response.
    if detected_lang == "french" and gemini_response_en not in gemini_errors:
        print("Translating English response back to French...")
        translated_back_fr = await translate_text_with_fallback(gemini_response_en, "en-fr")
        if translated_back_fr.strip() and translated_back_fr not in translation_errors:
            final_response = translated_back_fr
        else:
            print("English to French translation resulted in empty string or error, using English Gemini response.")
            final_response = gemini_response_en  # Fall back to English if back-translation fails.
    return final_response
# --- FastAPI Application ---
app = FastAPI()


# Request body model for the /chat endpoint.
class ChatRequest(BaseModel):
    # Raw user message; may be English or French.
    user_input: str
# FIX: route decorator restored — without @app.get the function was never
# registered with FastAPI and the landing route did not exist.
@app.get("/")
async def root():
    """Landing/health endpoint pointing callers at /chat."""
    return {"message": "Multilingual Chatbot API is running. Use /chat endpoint."}
# FIX: route decorator restored — the root message advertises a /chat
# endpoint, but without @app.post this handler was never registered.
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """
    Endpoint for the multilingual chatbot.

    Receives user input, detects its language, translates to English if
    needed, gets a Gemini response, and translates back when appropriate.
    Empty input gets a prompt to provide some, without calling the model.
    """
    user_input = request.user_input
    if not user_input:
        return {"response": "Please provide some input."}
    response = await get_multilingual_chatbot_response(user_input)
    return {"response": response}