Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -101,19 +101,28 @@ def load_model():
|
|
| 101 |
|
| 102 |
model = load_model()
|
| 103 |
|
| 104 |
-
# Define
|
| 105 |
-
chars = sorted(list(set("
|
| 106 |
stoi = {ch: i for i, ch in enumerate(chars)}
|
| 107 |
itos = {i: ch for i, ch in enumerate(chars)}
|
| 108 |
-
encode = lambda s: [stoi[c] for c in s]
|
| 109 |
decode = lambda l: ''.join([itos[i] for i in l])
|
| 110 |
|
| 111 |
# Function to generate text using the model
|
| 112 |
def generate_text(prompt):
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
# Create a Gradio interface
|
| 119 |
interface = gr.Interface(
|
|
@@ -125,4 +134,4 @@ interface = gr.Interface(
|
|
| 125 |
)
|
| 126 |
|
| 127 |
# Launch the interface
|
| 128 |
-
interface.launch()
|
|
|
|
| 101 |
|
| 102 |
model = load_model()
|
| 103 |
|
| 104 |
+
# Define a comprehensive character set based on training data
|
| 105 |
+
chars = sorted(list(set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?-:;'\"\n")))
|
| 106 |
stoi = {ch: i for i, ch in enumerate(chars)}
|
| 107 |
itos = {i: ch for i, ch in enumerate(chars)}
|
| 108 |
+
encode = lambda s: [stoi[c] for c in s if c in stoi] # Ensures only known characters are encoded
|
| 109 |
decode = lambda l: ''.join([itos[i] for i in l])
|
| 110 |
|
| 111 |
# Function to generate text using the model
|
| 112 |
def generate_text(prompt):
|
| 113 |
+
try:
|
| 114 |
+
print(f"Received prompt: {prompt}")
|
| 115 |
+
context = torch.tensor([encode(prompt)], dtype=torch.long)
|
| 116 |
+
print(f"Encoded prompt: {context}")
|
| 117 |
+
with torch.no_grad():
|
| 118 |
+
generated = model.generate(context, max_new_tokens=250) # Adjust as needed
|
| 119 |
+
print(f"Generated tensor: {generated}")
|
| 120 |
+
result = decode(generated[0].tolist())
|
| 121 |
+
print(f"Decoded result: {result}")
|
| 122 |
+
return result
|
| 123 |
+
except Exception as e:
|
| 124 |
+
print(f"Error during generation: {e}")
|
| 125 |
+
return f"Error: {str(e)}"
|
| 126 |
|
| 127 |
# Create a Gradio interface
|
| 128 |
interface = gr.Interface(
|
|
|
|
| 134 |
)
|
| 135 |
|
| 136 |
# Launch the interface
|
| 137 |
+
interface.launch(share=True)
|