Spaces:
Sleeping
Sleeping
File size: 8,998 Bytes
7c91070 b8b0952 7c91070 b8b0952 7c91070 b8b0952 7c91070 b8b0952 7c91070 b8b0952 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
from fastapi import FastAPI
from pydantic import BaseModel
import json
import os
import requests
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from dotenv import load_dotenv
# --- Load environment variables ---
# Pulls GEMINI_API_KEY (and anything else) from a local .env file into os.environ.
load_dotenv()
# --- Configuration ---
# Hugging Face Hub repo IDs for the custom-trained MarianMT translation models.
HF_EN_FR_REPO_ID = "cgpcorpbot/cgp_model_en-fr"
HF_FR_EN_REPO_ID = "cgpcorpbot/cgp_model_fr-en"
# Gemini API key; empty string when unset — downstream code checks for this
# and returns sentinel error strings instead of calling the API.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
# generateContent endpoint for the Gemini 2.0 Flash model (the API key is
# appended as a ?key= query parameter at call time).
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
# --- Load MarianMT models and tokenizers using pipeline ---
# Both translation pipelines are built once at import time so every request
# reuses the already-loaded weights. Any failure aborts startup via RuntimeError.
try:
    # Inference is pinned to CPU — no GPU is assumed in this environment.
    device = torch.device("cpu")
    print("Loading EN->FR model...")
    tokenizer_en_fr = AutoTokenizer.from_pretrained(HF_EN_FR_REPO_ID)
    model_en_fr = AutoModelForSeq2SeqLM.from_pretrained(HF_EN_FR_REPO_ID).to(device)
    # Wrap model + tokenizer in a transformers translation pipeline for simple calls.
    translator_en_fr = pipeline("translation", model=model_en_fr, tokenizer=tokenizer_en_fr, device=device)
    print(f"✅ MarianMT EN-FR Model loaded from Hugging Face Hub: {HF_EN_FR_REPO_ID} and moved to {device}")
    print("Loading FR->EN model...")
    tokenizer_fr_en = AutoTokenizer.from_pretrained(HF_FR_EN_REPO_ID)
    model_fr_en = AutoModelForSeq2SeqLM.from_pretrained(HF_FR_EN_REPO_ID).to(device)
    # Same wrapping for the reverse direction.
    translator_fr_en = pipeline("translation", model=model_fr_en, tokenizer=tokenizer_fr_en, device=device)
    print(f"✅ MarianMT FR-EN Model loaded from Hugging Face Hub: {HF_FR_EN_REPO_ID} and moved to {device}")
except Exception as e:
    # Translation is core functionality: log loudly and refuse to start.
    print(f"❌ Failed to load MarianMT models from Hugging Face Hub: {e}")
    raise RuntimeError(f"Failed to load translation models: {e}")
# --- Language Detection (Simplified) ---
def detect_language(text: str) -> str:
    """Crudely classify *text* as "french" or "english".

    Counts how many common French function words occur among the
    whitespace-separated tokens of the lowercased input; more than two
    distinct hits classifies the text as French, otherwise English.
    Empty input yields "english".
    """
    french_keywords = {"le", "la", "les", "un", "une", "des", "est", "sont", "je", "tu", "il", "elle", "nous", "vous", "ils", "elles", "pas", "de", "du", "et", "à", "en", "que", "qui", "quoi", "comment", "où", "quand"}
    # FIX: the original re-split the text once per keyword (27 splits per call);
    # tokenize once and count distinct keyword hits via set intersection —
    # identical result, single pass.
    tokens = set(text.lower().split())
    french_word_count = len(french_keywords & tokens)
    if french_word_count > 2:
        return "french"
    return "english"
# --- MarianMT Translation with Gemini Fallback ---
async def translate_text_with_fallback(text: str, direction: str) -> str:
    """
    Translates text using custom MarianMT models, falling back to Gemini if the
    local model fails.

    `direction` must be "en-fr" or "fr-en"; any other value returns the literal
    string "Invalid translation direction.". Empty/whitespace-only input
    short-circuits to "".
    """
    if not text.strip():
        return ""
    try:
        if direction == "en-fr":
            # Use the pipeline for translation
            translated_result = translator_en_fr(text, max_length=128)
            return translated_result[0]['translation_text'].strip()
        elif direction == "fr-en":
            # Use the pipeline for translation
            translated_result = translator_fr_en(text, max_length=128)
            return translated_result[0]['translation_text'].strip()
        else:
            return "Invalid translation direction."
    except Exception as e:
        print(f"⚠️ Local MarianMT model failed for {direction}, falling back to Gemini for translation: {e}")
        # BUG FIX: gemini_translate_for_translation is a coroutine function; the
        # original (sync) version returned the un-awaited coroutine object
        # instead of a string. Declaring this function async matches its callers
        # (which already `await` it) and awaiting the fallback yields real text.
        return await gemini_translate_for_translation(text, direction)
async def gemini_translate_for_translation(text: str, direction: str) -> str:
    """
    Uses the Gemini API for translation if MarianMT fails.

    This is a separate function specifically for translation fallback, not for
    general chatbot responses. Returns the translated text, or one of the
    sentinel strings "API key missing for translation." /
    "Translation failed via Gemini." / "Gemini API translation error." on
    failure (callers compare against these exact strings).
    """
    if not GEMINI_API_KEY:
        print("❌ Gemini API Key is missing for translation fallback.")
        return "API key missing for translation."
    target_lang = "French" if direction == "en-fr" else "English"
    # Prompt Gemini to perform translation of the raw text.
    prompt = f"Translate the following text to {target_lang}: \"{text}\""
    chat_history = [{"role": "user", "parts": [{"text": prompt}]}]
    payload = {"contents": chat_history}
    headers = {'Content-Type': 'application/json'}
    api_url_with_key = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}"
    try:
        # BUG FIX: added a timeout — requests.post without one can block forever
        # on a stalled connection. NOTE(review): requests is synchronous and
        # blocks the event loop inside this async def; httpx.AsyncClient would
        # be the non-blocking alternative.
        response = requests.post(api_url_with_key, headers=headers, data=json.dumps(payload), timeout=30)
        response.raise_for_status()
        result = response.json()
        # Defensive walk of candidates[0].content.parts[0].text; a missing level
        # falls through to the default string (or to the except below on IndexError).
        return result.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "Translation failed via Gemini.")
    except Exception as e:
        print(f"Gemini API translation fallback error: {e}")
        return "Gemini API translation error."
# --- Main Gemini API Call for Conversational Response ---
async def call_gemini_api_for_response(prompt: str) -> str:
    """
    Calls the Gemini API to get a conversational response in English.

    This is the primary function for generating chatbot responses; translation
    to/from the user's language happens in the caller. Returns the model's text
    on success, or one of the sentinel strings "API key missing." /
    "No response from Gemini." / "Gemini API error for main response." on
    failure (the caller compares against these exact strings).
    """
    if not GEMINI_API_KEY:
        print("❌ Gemini API Key is missing for main response.")
        return "API key missing."
    # Gemini is always prompted in English for consistent behavior.
    gemini_prompt = f"Answer the following question in English: {prompt}"
    chat_history = [{"role": "user", "parts": [{"text": gemini_prompt}]}]
    payload = {"contents": chat_history}
    headers = {'Content-Type': 'application/json'}
    api_url_with_key = f"{GEMINI_API_URL}?key={GEMINI_API_KEY}"
    try:
        # BUG FIX: added a timeout — requests.post without one can block forever
        # on a stalled connection. NOTE(review): requests is synchronous and
        # blocks the event loop inside this async def; httpx.AsyncClient would
        # be the non-blocking alternative.
        response = requests.post(api_url_with_key, headers=headers, data=json.dumps(payload), timeout=30)
        response.raise_for_status()
        result = response.json()
        # Defensive walk of candidates[0].content.parts[0].text; a missing level
        # falls through to the default string (or to the except below on IndexError).
        return result.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "No response from Gemini.")
    except Exception as e:
        print(f"Gemini API error for main response: {e}")
        return "Gemini API error for main response."
# --- Main Chatbot Response Logic ---
async def get_multilingual_chatbot_response(user_input: str) -> str:
    """
    Generates a chatbot response using MarianMT for translation and Gemini for core logic.
    Handles language detection, translation to English, Gemini interaction, and
    translation of the answer back to French when the input was French.

    NOTE(review): failures are detected by comparing against the exact sentinel
    strings returned by the translation/Gemini helpers — keep them in sync.
    """
    detected_lang = detect_language(user_input)
    print(f"Detected language: {detected_lang.upper()}")
    english_query = user_input
    if detected_lang == "french":
        print("Translating French query to English...")
        # Use the translation function with Gemini fallback (a coroutine — must be awaited).
        english_query = await translate_text_with_fallback(user_input, "fr-en")
        print(f"Translated query (EN): {english_query}")
        if not english_query.strip() or english_query == "API key missing for translation." or english_query == "Gemini API translation error.":
            # If translation fails or is empty, fall back to the untranslated input.
            english_query = user_input
            print("French to English translation resulted in empty string or error, using original input for Gemini.")
    # Get conversational response from Gemini (always in English).
    gemini_response_en = await call_gemini_api_for_response(english_query)
    print(f"Gemini response (EN): {gemini_response_en}")
    final_response = gemini_response_en
    # Only translate back if original input was French AND Gemini provided a
    # valid response (again detected via its sentinel error strings).
    if detected_lang == "french" and gemini_response_en not in ["API key missing for main response.", "Gemini API error for main response.", "No response from Gemini."]:
        print("Translating English response back to French...")
        # Use the translation function with fallback.
        translated_back_fr = await translate_text_with_fallback(gemini_response_en, "en-fr")
        if translated_back_fr.strip() and translated_back_fr not in ["API key missing for translation.", "Gemini API translation error."]:
            final_response = translated_back_fr
        else:
            print("English to French translation resulted in empty string or error, using English Gemini response.")
            final_response = gemini_response_en  # Fallback to English if translation back fails
    return final_response
# --- FastAPI Application ---
app = FastAPI()

# Request body schema for the /chat endpoint.
class ChatRequest(BaseModel):
    # Raw user message; may be English or French (detected server-side).
    user_input: str
@app.get("/")
async def root():
    """Liveness check that points clients at the /chat endpoint."""
    status = {"message": "Multilingual Chatbot API is running. Use /chat endpoint."}
    return status
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """
    Endpoint for the multilingual chatbot.

    Receives user input, runs it through the language-detection / translation /
    Gemini pipeline, and returns the reply (translated back when the input was
    French). Empty input is answered with a prompt to provide some text.
    """
    message = request.user_input
    if not message:
        return {"response": "Please provide some input."}
    reply = await get_multilingual_chatbot_response(message)
    return {"response": reply}
|