File size: 8,998 Bytes
7c91070
 
b8b0952
 
 
 
 
 
7c91070
b8b0952
 
7c91070
b8b0952
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c91070
 
b8b0952
 
 
 
 
 
 
 
 
 
 
 
 
7c91070
b8b0952
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import asyncio
import json
import os

import requests
import torch
from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# --- Load environment variables ---
load_dotenv()

# --- Configuration ---
# Hugging Face Hub IDs for your trained MarianMT models
HF_EN_FR_REPO_ID = "cgpcorpbot/cgp_model_en-fr"
HF_FR_EN_REPO_ID = "cgpcorpbot/cgp_model_fr-en"

# Gemini API configuration
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
# Correct Gemini API URL for generateContent
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"

# --- Load MarianMT models and tokenizers using pipeline ---
# These will be loaded once when the app starts
try:
    # Set device to CPU
    device = torch.device("cpu")

    print("Loading EN->FR model...")
    tokenizer_en_fr = AutoTokenizer.from_pretrained(HF_EN_FR_REPO_ID)
    model_en_fr = AutoModelForSeq2SeqLM.from_pretrained(HF_EN_FR_REPO_ID).to(device)
    # Using pipeline for easier translation
    translator_en_fr = pipeline("translation", model=model_en_fr, tokenizer=tokenizer_en_fr, device=device)
    print(f"✅ MarianMT EN-FR Model loaded from Hugging Face Hub: {HF_EN_FR_REPO_ID} and moved to {device}")

    print("Loading FR->EN model...")
    tokenizer_fr_en = AutoTokenizer.from_pretrained(HF_FR_EN_REPO_ID)
    model_fr_en = AutoModelForSeq2SeqLM.from_pretrained(HF_FR_EN_REPO_ID).to(device)
    # Using pipeline for easier translation
    translator_fr_en = pipeline("translation", model=model_fr_en, tokenizer=tokenizer_fr_en, device=device)
    print(f"✅ MarianMT FR-EN Model loaded from Hugging Face Hub: {HF_FR_EN_REPO_ID} and moved to {device}")

except Exception as e:
    print(f"❌ Failed to load MarianMT models from Hugging Face Hub: {e}")
    raise RuntimeError(f"Failed to load translation models: {e}")

# --- Language Detection (Simplified) ---
def detect_language(text: str) -> str:
    text_lower = text.lower()
    french_keywords = ["le", "la", "les", "un", "une", "des", "est", "sont", "je", "tu", "il", "elle", "nous", "vous", "ils", "elles", "pas", "de", "du", "et", "à", "en", "que", "qui", "quoi", "comment", "où", "quand"]
    
    french_word_count = sum(1 for word in french_keywords if word in text_lower.split())
    
    if french_word_count > 2:
        return "french"
    return "english"

# --- MarianMT Translation with Gemini Fallback ---
def translate_text_with_fallback(text: str, direction: str) -> str:
    """
    Translates text using custom MarianMT models, falling back to Gemini if local model fails.
    `direction` can be "en-fr" or "fr-en".
    """
    if not text.strip():
        return ""

    try:
        if direction == "en-fr":
            # Use the pipeline for translation
            translated_result = translator_en_fr(text, max_length=128)
            return translated_result[0]['translation_text'].strip()
        elif direction == "fr-en":
            # Use the pipeline for translation
            translated_result = translator_fr_en(text, max_length=128)
            return translated_result[0]['translation_text'].strip()
        else:
            return "Invalid translation direction."
    except Exception as e:
        print(f"⚠️ Local MarianMT model failed for {direction}, falling back to Gemini for translation: {e}")
        # Fallback to Gemini for translation if MarianMT fails
        return gemini_translate_for_translation(text, direction)

async def gemini_translate_for_translation(text: str, direction: str) -> str:
    """
    Uses Gemini API for translation if MarianMT fails.
    This is a separate function specifically for translation fallback,
    not for general chatbot responses.
    """
    if not GEMINI_API_KEY:
        print("❌ Gemini API Key is missing for translation fallback.")
        return "API key missing for translation."

    target_lang = "French" if direction == "en-fr" else "English"
    # Prompt Gemini to perform translation
    prompt = f"Translate the following text to {target_lang}: \"{text}\""

    chat_history = [{"role": "user", "parts": [{"text": prompt}]}]
    payload = {"contents": chat_history}
    headers = {'Content-Type': 'application/json'}
    api_url_with_key = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}"

    try:
        response = requests.post(api_url_with_key, headers=headers, data=json.dumps(payload))
        response.raise_for_status()
        result = response.json()
        return result.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "Translation failed via Gemini.")
    except Exception as e:
        print(f"Gemini API translation fallback error: {e}")
        return "Gemini API translation error."

# --- Main Gemini API Call for Conversational Response ---
async def call_gemini_api_for_response(prompt: str) -> str:
    """
    Calls the Gemini API to get a conversational response in English.
    This is the primary function for generating chatbot responses.
    """
    if not GEMINI_API_KEY:
        print("❌ Gemini API Key is missing for main response.")
        return "API key missing."

    chat_history = []
    # Gemini will always be prompted in English for consistent behavior
    gemini_prompt = f"Answer the following question in English: {prompt}"

    chat_history.append({"role": "user", "parts": [{"text": gemini_prompt}]})
    payload = {"contents": chat_history}

    headers = {'Content-Type': 'application/json'}
    api_url_with_key = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}"

    try:
        response = requests.post(api_url_with_key, headers=headers, data=json.dumps(payload))
        response.raise_for_status()
        result = response.json()

        return result.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "No response from Gemini.")
    except Exception as e:
        print(f"Gemini API error for main response: {e}")
        return "Gemini API error for main response."

# --- Main Chatbot Response Logic ---
async def get_multilingual_chatbot_response(user_input: str) -> str:
    """
    Generates a chatbot response using MarianMT for translation and Gemini for core logic.
    Handles language detection, translation, Gemini interaction, and translation back.
    """
    detected_lang = detect_language(user_input)
    print(f"Detected language: {detected_lang.upper()}")

    english_query = user_input
    if detected_lang == "french":
        print("Translating French query to English...")
        # Use the translation function with fallback
        english_query = await translate_text_with_fallback(user_input, "fr-en")
        print(f"Translated query (EN): {english_query}")
        if not english_query.strip() or english_query == "API key missing for translation." or english_query == "Gemini API translation error.":
            # If translation fails or is empty, use original input for Gemini
            english_query = user_input
            print("French to English translation resulted in empty string or error, using original input for Gemini.")

    # Get conversational response from Gemini (always in English)
    gemini_response_en = await call_gemini_api_for_response(english_query)
    print(f"Gemini response (EN): {gemini_response_en}")

    final_response = gemini_response_en
    # Only translate back if original input was French AND Gemini provided a valid response
    if detected_lang == "french" and gemini_response_en not in ["API key missing for main response.", "Gemini API error for main response.", "No response from Gemini."]:
        print("Translating English response back to French...")
        # Use the translation function with fallback
        translated_back_fr = await translate_text_with_fallback(gemini_response_en, "en-fr")
        if translated_back_fr.strip() and translated_back_fr not in ["API key missing for translation.", "Gemini API translation error."]:
            final_response = translated_back_fr
        else:
            print("English to French translation resulted in empty string or error, using English Gemini response.")
            final_response = gemini_response_en # Fallback to English if translation back fails

    return final_response

# --- FastAPI Application ---
app = FastAPI()

# Define the request body model
class ChatRequest(BaseModel):
    user_input: str

@app.get("/")
async def root():
    return {"message": "Multilingual Chatbot API is running. Use /chat endpoint."}

@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """
    Endpoint for the multilingual chatbot.
    Receives user input, detects language, translates, gets Gemini response,
    and translates back if necessary.
    """
    user_input = request.user_input
    if not user_input:
        return {"response": "Please provide some input."}

    response = await get_multilingual_chatbot_response(user_input)
    return {"response": response}