abhinav0231 commited on
Commit
2e2d598
·
verified ·
1 Parent(s): 941449c

Update llm.py

Browse files
Files changed (1) hide show
  1. llm.py +104 -116
llm.py CHANGED
@@ -1,116 +1,104 @@
1
- import os
2
- from langdetect import detect, LangDetectException
3
- from llm_setup import llm
4
- from prompts import get_story_prompt
5
- from rag_agent import run_rag_agent
6
- from audio_transcription import transcribe_audio_with_auto_detect
7
- from langchain.prompts import PromptTemplate
8
- from langchain_core.output_parsers import StrOutputParser
9
- import pycountry
10
-
11
- if not llm:
12
- raise ImportError("LLM could not be loaded.")
13
-
14
def detect_language(text: str) -> str:
    """Return the ISO 639-1 code of *text*'s language, falling back to "en"."""
    try:
        code = detect(text)
    except LangDetectException:
        # langdetect raises on empty or ambiguous input; default to English.
        print("Warning: Could not detect language, defaulting to English (en).")
        code = "en"
    return code
21
-
22
def get_language_name(code: str) -> str:
    """Map a two-letter language code (e.g. 'en') to its full name (e.g. 'English')."""
    try:
        match = pycountry.languages.get(alpha_2=code)
        return match.name
    except AttributeError:
        # Unknown codes make pycountry return None (no .name); default to English.
        return "English"
28
-
29
def detect_target_language(user_prompt: str, input_language_name: str) -> str:
    """Ask the LLM which language the story should be written in.

    Falls back to the input language when the LLM call fails.
    """
    print("--- Detecting target language from prompt... ---")
    template = PromptTemplate.from_template(
        """Analyze the user's request below. Your task is to determine the desired output language for a story.
        - If the user explicitly mentions a language, return that language name.
        - If the user does NOT explicitly mention an output language, assume they want the story in their input language.
        - Your response MUST be only the name of the language (e.g., "Hindi", "English").
        Input Language: "{input_language}"
        User's Request: "{prompt}"
        Output Language:"""
    )
    pipeline = template | llm | StrOutputParser()
    try:
        answer = pipeline.invoke({"prompt": user_prompt, "input_language": input_language_name})
        print(f"βœ… LLM detected target language: {answer.strip()}")
        return answer.strip()
    except Exception as e:
        print(f"Warning: Could not detect target language: {e}. Defaulting to input language.")
        return input_language_name
49
-
50
def generate_story(user_prompt: str, story_style: str, audio_file_path: str = None, doc_file_path: str = None) -> str:
    """Run the end-to-end story pipeline; RAG runs only when a document is given."""
    print("--- Starting Story Generation Pipeline ---")

    # Resolve the input language: from transcription metadata for audio input,
    # otherwise by detecting it from the text prompt.
    lang_code = "en"
    if audio_file_path:
        print("Input source: Audio file")
        transcription = transcribe_audio_with_auto_detect(audio_file_path)
        user_prompt = transcription.get("text")
        audio_lang = transcription.get("detected_language")
        if not user_prompt or "Error:" in user_prompt:
            return f"Could not generate story due to a transcription error: {user_prompt}"
        if audio_lang:
            # Normalise region-tagged codes such as "en-US" to bare "en".
            lang_code = audio_lang.split('-')[0]
    else:
        lang_code = detect_language(user_prompt)

    input_language_name = get_language_name(lang_code)
    input_lang_code = lang_code
    print(f"Detected Input Language: {input_language_name} ({input_lang_code})")

    target_language = detect_target_language(user_prompt, input_language_name)

    rag_context = ""
    if doc_file_path:
        print("\n[Step 1/3] Document provided. Retrieving context with RAG Agent...")
        rag_context = run_rag_agent(user_prompt, file_path=doc_file_path)
        print("βœ… Context retrieval complete.")
    else:
        print("\n[Step 1/3] No document provided. Skipping RAG.")

    # With or without retrieved context, build the final prompt.
    print("\n[Step 2/3] Engineering the final prompt...")
    print(f"πŸ“‹ Story style received: '{story_style}'")
    final_prompt = get_story_prompt(user_prompt, story_style, target_language, rag_context)
    print("βœ… Prompt engineering complete.")

    # get_story_prompt returns None for an unrecognised style; fail fast.
    if final_prompt is None:
        print(f"❌ ERROR: get_story_prompt returned None for style: '{story_style}'")
        return f"Error: Invalid story style '{story_style}'. Please select a valid style."

    print(f"βœ… Prompt engineering complete. Prompt length: {len(final_prompt)} characters")

    print("\n[Step 3/3] Calling the LLM to generate the story...")
    try:
        reply = llm.invoke(final_prompt)
        story = reply.content
        print("βœ… Story generation complete.")
    except Exception as e:
        print(f"❌ An error occurred while calling the LLM: {e}")
        story = f"Error: Could not generate the story. LLM Error: {str(e)}"

    return story
105
-
106
- # if __name__ == '__main__':
107
- # prompt = "write a story about animals and how everybody lived in peace and harmony"
108
-
109
- # print("--- RUNNING CROSS-LANGUAGE TEST CASE ---")
110
- # generated_story = generate_story(
111
- # user_prompt=prompt,
112
- # story_style="Indian Wisdom"
113
- # )
114
-
115
- # print("\n\n--- GENERATED STORY ---")
116
- # print(generated_story)
 
1
+ import os
2
+ from langdetect import detect, LangDetectException
3
+ from llm_setup import llm
4
+ from prompts import get_story_prompt
5
+ from rag_agent import run_rag_agent
6
+ from audio_transcription import transcribe_audio_with_auto_detect
7
+ from langchain.prompts import PromptTemplate
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ import pycountry
10
+
11
+ if not llm:
12
+ raise ImportError("LLM could not be loaded.")
13
+
14
def detect_language(text: str) -> str:
    """Detect the language of *text* and return its ISO 639-1 code ("en" on failure)."""
    try:
        return detect(text)
    except LangDetectException:
        # Empty or ambiguous text cannot be classified; assume English.
        print("Warning: Could not detect language, defaulting to English (en).")
        return "en"
21
+
22
def get_language_name(code: str) -> str:
    """Resolve an ISO 639-1 code (e.g. 'en') to a language name (e.g. 'English')."""
    try:
        language = pycountry.languages.get(alpha_2=code)
        return language.name
    except AttributeError:
        # pycountry yields None for codes it does not know; use English then.
        return "English"
28
+
29
def detect_target_language(user_prompt: str, input_language_name: str) -> str:
    """Use the LLM to determine the desired output language from the user's prompt.

    Args:
        user_prompt: The user's raw story request.
        input_language_name: Full name of the language the request was written in.

    Returns:
        The target language name (e.g. "Hindi"). Falls back to
        ``input_language_name`` when the LLM call fails or returns nothing.
    """
    print("--- Detecting target language from prompt... ---")
    prompt_template = PromptTemplate.from_template(
        """Analyze the user's request below. Your task is to determine the desired output language for a story.
        - If the user explicitly mentions a language, return that language name.
        - If the user does NOT explicitly mention an output language, assume they want the story in their input language.
        - Your response MUST be only the name of the language (e.g., "Hindi", "English").
        Input Language: "{input_language}"
        User's Request: "{prompt}"
        Output Language:"""
    )
    chain = prompt_template | llm | StrOutputParser()
    try:
        detected_language = chain.invoke(
            {"prompt": user_prompt, "input_language": input_language_name}
        ).strip()  # strip once; the original stripped the same value twice
    except Exception as e:
        print(f"Warning: Could not detect target language: {e}. Defaulting to input language.")
        return input_language_name
    # Robustness fix: an empty/whitespace-only LLM reply previously propagated
    # as an empty target-language string; fall back to the input language.
    if not detected_language:
        print("Warning: Could not detect target language: empty response. Defaulting to input language.")
        return input_language_name
    print(f"βœ… LLM detected target language: {detected_language}")
    return detected_language
49
+
50
def generate_story(user_prompt: str, story_style: str, audio_file_path: str = None, doc_file_path: str = None) -> str:
    """
    Orchestrates the full story generation pipeline, only running RAG if a document is provided.

    Args:
        user_prompt: Text request from the user (replaced by the transcription
            when ``audio_file_path`` is supplied).
        story_style: Style name understood by ``get_story_prompt``.
        audio_file_path: Optional path to an audio file to transcribe.
        doc_file_path: Optional path to a document used for RAG context.

    Returns:
        The generated story, or a human-readable error string on failure.
    """
    print("--- Starting Story Generation Pipeline ---")

    # Resolve the input language: from transcription metadata when audio is
    # supplied, otherwise by detecting it from the text prompt.
    input_lang_code = "en"
    if audio_file_path:
        print("Input source: Audio file")
        transcription_result = transcribe_audio_with_auto_detect(audio_file_path)
        user_prompt = transcription_result.get("text")
        detected_lang_code = transcription_result.get("detected_language")
        if not user_prompt or "Error:" in user_prompt:
            return f"Could not generate story due to a transcription error: {user_prompt}"
        if detected_lang_code:
            # Normalise region-tagged codes such as "en-US" to plain "en".
            input_lang_code = detected_lang_code.split('-')[0]
    else:
        input_lang_code = detect_language(user_prompt)

    input_language_name = get_language_name(input_lang_code)
    print(f"Detected Input Language: {input_language_name} ({input_lang_code})")

    target_language = detect_target_language(user_prompt, input_language_name)
    rag_context = ""

    if doc_file_path:
        print("\n[Step 1/3] Document provided. Retrieving context with RAG Agent...")
        rag_context = run_rag_agent(user_prompt, file_path=doc_file_path)
        print("βœ… Context retrieval complete.")
    else:
        print("\n[Step 1/3] No document provided. Skipping RAG.")

    # The rest of the pipeline proceeds as normal, with or without context.
    print("\n[Step 2/3] Engineering the final prompt...")
    print(f"πŸ“‹ Story style received: '{story_style}'")
    final_prompt = get_story_prompt(user_prompt, story_style, target_language, rag_context)

    # get_story_prompt returns None for an unrecognised style; fail fast with a
    # clear message instead of handing None to the LLM.
    if final_prompt is None:
        print(f"❌ ERROR: get_story_prompt returned None for style: '{story_style}'")
        return f"Error: Invalid story style '{story_style}'. Please select a valid style."

    # Fix: "Prompt engineering complete" was previously printed twice; keep
    # only the informative variant that includes the prompt length.
    print(f"βœ… Prompt engineering complete. Prompt length: {len(final_prompt)} characters")

    print("\n[Step 3/3] Calling the LLM to generate the story...")
    try:
        response = llm.invoke(final_prompt)
        story = response.content
        print("βœ… Story generation complete.")
    except Exception as e:
        # Best-effort: surface the LLM failure as a readable string rather than raising.
        print(f"❌ An error occurred while calling the LLM: {e}")
        story = f"Error: Could not generate the story. LLM Error: {str(e)}"

    return story