roshnn24 commited on
Commit
d77a598
·
verified ·
1 Parent(s): f0e214a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -13
app.py CHANGED
@@ -10,6 +10,9 @@ from datetime import datetime
10
  import sqlite3
11
  from contextlib import contextmanager
12
  from werkzeug.utils import secure_filename
 
 
 
13
 
14
  # Initialize Flask application
15
  app = Flask(__name__)
@@ -64,6 +67,36 @@ prompt = PromptTemplate(
64
  template=prompt_template
65
  )
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  @contextmanager
68
  def get_db_connection():
69
  """Context manager for database connections"""
@@ -117,19 +150,8 @@ def initialize_llm():
117
  raise ValueError("HUGGINGFACE_API_TOKEN not found in environment variables")
118
  print("API token found")
119
 
120
- # Initialize with specific task parameters
121
- llm = HuggingFaceHub(
122
- repo_id="mistralai/Mistral-7B-Instruct-v0.1",
123
- huggingfacehub_api_token=api_token,
124
- task="text-generation", # Specify the task
125
- model_kwargs={
126
- "temperature": 0.7,
127
- "max_new_tokens": 512,
128
- "top_p": 0.95,
129
- "repetition_penalty": 1.15,
130
- "return_full_text": False,
131
- }
132
- )
133
 
134
  # Test the LLM
135
  print("Testing LLM with a simple prompt...")
 
10
  import sqlite3
11
  from contextlib import contextmanager
12
  from werkzeug.utils import secure_filename
13
+ from huggingface_hub import InferenceClient
14
+ from langchain.llms.base import LLM
15
+ from typing import Optional, List, Any
16
 
17
  # Initialize Flask application
18
  app = Flask(__name__)
 
67
  template=prompt_template
68
  )
69
 
70
+ class CustomHuggingFaceInference(LLM):
71
+ client: Any
72
+ model: str
73
+
74
+ def __init__(self, token: str):
75
+ super().__init__()
76
+ self.client = InferenceClient(token=token)
77
+ self.model = "mistralai/Mistral-7B-Instruct-v0.1"
78
+
79
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
80
+ response = self.client.text_generation(
81
+ prompt,
82
+ model=self.model,
83
+ max_new_tokens=512,
84
+ temperature=0.7,
85
+ top_p=0.95,
86
+ repetition_penalty=1.15
87
+ )
88
+ return response
89
+
90
+ @property
91
+ def _identifying_params(self):
92
+ return {"model": self.model}
93
+
94
+ @property
95
+ def _llm_type(self):
96
+ return "custom_huggingface"
97
+
98
+
99
+
100
  @contextmanager
101
  def get_db_connection():
102
  """Context manager for database connections"""
 
150
  raise ValueError("HUGGINGFACE_API_TOKEN not found in environment variables")
151
  print("API token found")
152
 
153
+ # Initialize with custom LLM
154
+ llm = CustomHuggingFaceInference(token=api_token)
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  # Test the LLM
157
  print("Testing LLM with a simple prompt...")