from fastapi import FastAPI
from pydantic import BaseModel
import asyncio
import json
import os
import requests
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from dotenv import load_dotenv

# --- Load environment variables ---
load_dotenv()

# --- Configuration ---
# Hugging Face Hub IDs for your trained MarianMT models
HF_EN_FR_REPO_ID = "cgpcorpbot/cgp_model_en-fr"
HF_FR_EN_REPO_ID = "cgpcorpbot/cgp_model_fr-en"

# Gemini API configuration
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
# Correct Gemini API URL for generateContent
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
# requests.post() has NO default timeout and can hang forever without one.
GEMINI_TIMEOUT_SECONDS = 30

# --- Load MarianMT models and tokenizers using pipeline ---
# Loaded once at import time; a failure here deliberately aborts app startup.
try:
    # Set device to CPU
    device = torch.device("cpu")

    print("Loading EN->FR model...")
    tokenizer_en_fr = AutoTokenizer.from_pretrained(HF_EN_FR_REPO_ID)
    model_en_fr = AutoModelForSeq2SeqLM.from_pretrained(HF_EN_FR_REPO_ID).to(device)
    # Using pipeline for easier translation
    translator_en_fr = pipeline("translation", model=model_en_fr, tokenizer=tokenizer_en_fr, device=device)
    print(f"✅ MarianMT EN-FR Model loaded from Hugging Face Hub: {HF_EN_FR_REPO_ID} and moved to {device}")

    print("Loading FR->EN model...")
    tokenizer_fr_en = AutoTokenizer.from_pretrained(HF_FR_EN_REPO_ID)
    model_fr_en = AutoModelForSeq2SeqLM.from_pretrained(HF_FR_EN_REPO_ID).to(device)
    # Using pipeline for easier translation
    translator_fr_en = pipeline("translation", model=model_fr_en, tokenizer=tokenizer_fr_en, device=device)
    print(f"✅ MarianMT FR-EN Model loaded from Hugging Face Hub: {HF_FR_EN_REPO_ID} and moved to {device}")
except Exception as e:
    print(f"❌ Failed to load MarianMT models from Hugging Face Hub: {e}")
    raise RuntimeError(f"Failed to load translation models: {e}")


# --- Language Detection (Simplified) ---
# Common French function words used as a cheap language signal.
_FRENCH_KEYWORDS = frozenset([
    "le", "la", "les", "un", "une", "des", "est", "sont", "je", "tu",
    "il", "elle", "nous", "vous", "ils", "elles", "pas", "de", "du",
    "et", "à", "en", "que", "qui", "quoi", "comment", "où", "quand",
])


def detect_language(text: str) -> str:
    """Heuristically classify *text* as "french" or "english".

    Counts how many *distinct* French function words appear as whole tokens;
    more than two hits classifies the text as French, otherwise English.
    (The original re-split the text once per keyword; splitting once and
    intersecting sets preserves the same distinct-hit count.)
    """
    words = set(text.lower().split())
    french_word_count = len(_FRENCH_KEYWORDS & words)
    if french_word_count > 2:
        return "french"
    return "english"


# --- MarianMT Translation with Gemini Fallback ---
async def translate_text_with_fallback(text: str, direction: str) -> str:
    """
    Translates text using custom MarianMT models, falling back to Gemini
    if the local model fails. `direction` can be "en-fr" or "fr-en".

    BUG FIX: this was a plain `def` while every caller `await`s it, and its
    fallback path returned the *unawaited coroutine* object from the async
    `gemini_translate_for_translation`. It is now a coroutine function and
    properly awaits the fallback.
    """
    if not text.strip():
        return ""
    try:
        if direction == "en-fr":
            # Use the pipeline for translation
            translated_result = translator_en_fr(text, max_length=128)
        elif direction == "fr-en":
            # Use the pipeline for translation
            translated_result = translator_fr_en(text, max_length=128)
        else:
            return "Invalid translation direction."
        return translated_result[0]['translation_text'].strip()
    except Exception as e:
        print(f"⚠️ Local MarianMT model failed for {direction}, falling back to Gemini for translation: {e}")
        # Fallback to Gemini for translation if MarianMT fails
        return await gemini_translate_for_translation(text, direction)


async def _post_to_gemini(prompt: str, default_text: str) -> str:
    """
    Send a single-turn *prompt* to the Gemini generateContent endpoint and
    return the first candidate's text, or *default_text* if the response
    carries no text. Raises on HTTP/network errors (callers handle these).

    The blocking `requests` call runs in a worker thread via
    `asyncio.to_thread` so the FastAPI event loop is not stalled, and an
    explicit timeout prevents an indefinite hang.
    """
    chat_history = [{"role": "user", "parts": [{"text": prompt}]}]
    payload = {"contents": chat_history}
    headers = {'Content-Type': 'application/json'}
    api_url_with_key = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}"

    def _blocking_post():
        response = requests.post(
            api_url_with_key,
            headers=headers,
            data=json.dumps(payload),
            timeout=GEMINI_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        return response.json()

    result = await asyncio.to_thread(_blocking_post)
    return (
        result.get("candidates", [{}])[0]
        .get("content", {})
        .get("parts", [{}])[0]
        .get("text", default_text)
    )


async def gemini_translate_for_translation(text: str, direction: str) -> str:
    """
    Uses the Gemini API for translation if MarianMT fails. This is a
    separate function specifically for translation fallback, not for
    general chatbot responses.
    """
    if not GEMINI_API_KEY:
        print("❌ Gemini API Key is missing for translation fallback.")
        return "API key missing for translation."

    target_lang = "French" if direction == "en-fr" else "English"
    # Prompt Gemini to perform translation
    prompt = f"Translate the following text to {target_lang}: \"{text}\""
    try:
        return await _post_to_gemini(prompt, "Translation failed via Gemini.")
    except Exception as e:
        print(f"Gemini API translation fallback error: {e}")
        return "Gemini API translation error."


# --- Main Gemini API Call for Conversational Response ---
async def call_gemini_api_for_response(prompt: str) -> str:
    """
    Calls the Gemini API to get a conversational response in English.
    This is the primary function for generating chatbot responses.
    """
    if not GEMINI_API_KEY:
        print("❌ Gemini API Key is missing for main response.")
        return "API key missing."

    # Gemini is always prompted in English for consistent behavior.
    gemini_prompt = f"Answer the following question in English: {prompt}"
    try:
        return await _post_to_gemini(gemini_prompt, "No response from Gemini.")
    except Exception as e:
        print(f"Gemini API error for main response: {e}")
        return "Gemini API error for main response."
# --- Main Chatbot Response Logic ---

# Sentinel strings returned by the translation helpers on failure.
_TRANSLATION_ERROR_SENTINELS = {
    "API key missing for translation.",
    "Gemini API translation error.",
}
# Sentinel strings returned by the main Gemini call on failure.
_GEMINI_ERROR_SENTINELS = {
    "API key missing for main response.",
    "Gemini API error for main response.",
    "No response from Gemini.",
}


async def get_multilingual_chatbot_response(user_input: str) -> str:
    """
    Generates a chatbot response using MarianMT for translation and Gemini
    for core logic.

    Pipeline: detect language -> (if French) translate the query to English
    -> ask Gemini in English -> (if French) translate the answer back.
    Any translation failure falls back to the untranslated text instead of
    surfacing an error message to the user.
    """
    detected_lang = detect_language(user_input)
    print(f"Detected language: {detected_lang.upper()}")

    english_query = user_input
    if detected_lang == "french":
        print("Translating French query to English...")
        # Use the translation function with fallback
        english_query = await translate_text_with_fallback(user_input, "fr-en")
        print(f"Translated query (EN): {english_query}")
        if not english_query.strip() or english_query in _TRANSLATION_ERROR_SENTINELS:
            # If translation fails or is empty, use original input for Gemini
            english_query = user_input
            print("French to English translation resulted in empty string or error, using original input for Gemini.")

    # Get conversational response from Gemini (always in English)
    gemini_response_en = await call_gemini_api_for_response(english_query)
    print(f"Gemini response (EN): {gemini_response_en}")

    final_response = gemini_response_en

    # Only translate back if the original input was French AND Gemini
    # provided a valid (non-error) response.
    if detected_lang == "french" and gemini_response_en not in _GEMINI_ERROR_SENTINELS:
        print("Translating English response back to French...")
        # Use the translation function with fallback
        translated_back_fr = await translate_text_with_fallback(gemini_response_en, "en-fr")
        if translated_back_fr.strip() and translated_back_fr not in _TRANSLATION_ERROR_SENTINELS:
            final_response = translated_back_fr
        else:
            print("English to French translation resulted in empty string or error, using English Gemini response.")
            # Fall back to the English answer if back-translation fails.
            final_response = gemini_response_en

    return final_response


# --- FastAPI Application ---
app = FastAPI()


# Request body for the /chat endpoint.
class ChatRequest(BaseModel):
    # Raw user message; may be English or French.
    user_input: str


@app.get("/")
async def root():
    """Health/landing endpoint."""
    return {"message": "Multilingual Chatbot API is running. Use /chat endpoint."}


@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """
    Endpoint for the multilingual chatbot. Receives user input, detects
    language, translates, gets a Gemini response, and translates back
    if necessary.
    """
    user_input = request.user_input
    # Robustness: also reject whitespace-only input, which would otherwise
    # flow through the whole translation pipeline as an empty query.
    if not user_input or not user_input.strip():
        return {"response": "Please provide some input."}
    response = await get_multilingual_chatbot_response(user_input)
    return {"response": response}